diff --git a/Modules/_billiard/clinic/posixshmem.c.h b/Modules/_billiard/clinic/posixshmem.c.h new file mode 100644 index 00000000..3424b10a --- /dev/null +++ b/Modules/_billiard/clinic/posixshmem.c.h @@ -0,0 +1,123 @@ +/*[clinic input] +preserve +[clinic start generated code]*/ + +#if defined(HAVE_SHM_OPEN) + +PyDoc_STRVAR(_posixshmem_shm_open__doc__, +"shm_open($module, /, path, flags, mode=511)\n" +"--\n" +"\n" +"Open a shared memory object. Returns a file descriptor (integer)."); + +#define _POSIXSHMEM_SHM_OPEN_METHODDEF \ + {"shm_open", (PyCFunction)(void(*)(void))_posixshmem_shm_open, METH_FASTCALL|METH_KEYWORDS, _posixshmem_shm_open__doc__}, + +static int +_posixshmem_shm_open_impl(PyObject *module, PyObject *path, int flags, + int mode); + +static PyObject * +_posixshmem_shm_open(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +{ + PyObject *return_value = NULL; + static const char * const _keywords[] = {"path", "flags", "mode", NULL}; + static _PyArg_Parser _parser = {NULL, _keywords, "shm_open", 0}; + PyObject *argsbuf[3]; + Py_ssize_t noptargs = nargs + (kwnames ? 
PyTuple_GET_SIZE(kwnames) : 0) - 2; + PyObject *path; + int flags; + int mode = 511; + int _return_value; + + args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 2, 3, 0, argsbuf); + if (!args) { + goto exit; + } + if (!PyUnicode_Check(args[0])) { + _PyArg_BadArgument("shm_open", "argument 'path'", "str", args[0]); + goto exit; + } + if (PyUnicode_READY(args[0]) == -1) { + goto exit; + } + path = args[0]; + flags = _PyLong_AsInt(args[1]); + if (flags == -1 && PyErr_Occurred()) { + goto exit; + } + if (!noptargs) { + goto skip_optional_pos; + } + mode = _PyLong_AsInt(args[2]); + if (mode == -1 && PyErr_Occurred()) { + goto exit; + } +skip_optional_pos: + _return_value = _posixshmem_shm_open_impl(module, path, flags, mode); + if ((_return_value == -1) && PyErr_Occurred()) { + goto exit; + } + return_value = PyLong_FromLong((long)_return_value); + +exit: + return return_value; +} + +#endif /* defined(HAVE_SHM_OPEN) */ + +#if defined(HAVE_SHM_UNLINK) + +PyDoc_STRVAR(_posixshmem_shm_unlink__doc__, +"shm_unlink($module, /, path)\n" +"--\n" +"\n" +"Remove a shared memory object (similar to unlink()).\n" +"\n" +"Remove a shared memory object name, and, once all processes have unmapped\n" +"the object, de-allocates and destroys the contents of the associated memory\n" +"region."); + +#define _POSIXSHMEM_SHM_UNLINK_METHODDEF \ + {"shm_unlink", (PyCFunction)(void(*)(void))_posixshmem_shm_unlink, METH_FASTCALL|METH_KEYWORDS, _posixshmem_shm_unlink__doc__}, + +static PyObject * +_posixshmem_shm_unlink_impl(PyObject *module, PyObject *path); + +static PyObject * +_posixshmem_shm_unlink(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +{ + PyObject *return_value = NULL; + static const char * const _keywords[] = {"path", NULL}; + static _PyArg_Parser _parser = {NULL, _keywords, "shm_unlink", 0}; + PyObject *argsbuf[1]; + PyObject *path; + + args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 1, 1, 0, argsbuf); + if (!args) 
{ + goto exit; + } + if (!PyUnicode_Check(args[0])) { + _PyArg_BadArgument("shm_unlink", "argument 'path'", "str", args[0]); + goto exit; + } + if (PyUnicode_READY(args[0]) == -1) { + goto exit; + } + path = args[0]; + return_value = _posixshmem_shm_unlink_impl(module, path); + +exit: + return return_value; +} + +#endif /* defined(HAVE_SHM_UNLINK) */ + +#ifndef _POSIXSHMEM_SHM_OPEN_METHODDEF + #define _POSIXSHMEM_SHM_OPEN_METHODDEF +#endif /* !defined(_POSIXSHMEM_SHM_OPEN_METHODDEF) */ + +#ifndef _POSIXSHMEM_SHM_UNLINK_METHODDEF + #define _POSIXSHMEM_SHM_UNLINK_METHODDEF +#endif /* !defined(_POSIXSHMEM_SHM_UNLINK_METHODDEF) */ +/*[clinic end generated code: output=bca8e78d0f43ef1a input=a9049054013a1b77]*/ diff --git a/Modules/_billiard/multiprocessing.c b/Modules/_billiard/multiprocessing.c index 0d5af9d5..806e6383 100644 --- a/Modules/_billiard/multiprocessing.c +++ b/Modules/_billiard/multiprocessing.c @@ -9,22 +9,15 @@ #include "multiprocessing.h" -#ifdef SCM_RIGHTS - #define HAVE_FD_TRANSFER 1 -#else - #define HAVE_FD_TRANSFER 0 -#endif /* * Function which raises exceptions based on error codes */ PyObject * -Billiard_SetError(PyObject *Type, int num) +_PyMp_SetError(PyObject *Type, int num) { switch (num) { - case MP_SUCCESS: - break; #ifdef MS_WINDOWS case MP_STANDARD_ERROR: if (Type == NULL) @@ -47,16 +40,6 @@ Billiard_SetError(PyObject *Type, int num) case MP_MEMORY_ERROR: PyErr_NoMemory(); break; - case MP_END_OF_FILE: - PyErr_SetNone(PyExc_EOFError); - break; - case MP_EARLY_END_OF_FILE: - PyErr_SetString(PyExc_IOError, - "got end of file during message"); - break; - case MP_BAD_MESSAGE_LENGTH: - PyErr_SetString(PyExc_IOError, "bad message length"); - break; case MP_EXCEPTION_HAS_BEEN_SET: break; default: @@ -66,27 +49,9 @@ Billiard_SetError(PyObject *Type, int num) return NULL; } - -/* - * Windows only - */ - #ifdef MS_WINDOWS - -/* On Windows we set an event to signal Ctrl-C; compare with timemodule.c */ - -HANDLE sigint_event = NULL; - -static BOOL 
WINAPI -ProcessingCtrlHandler(DWORD dwCtrlType) -{ - SetEvent(sigint_event); - return FALSE; -} - - static PyObject * -Billiard_closesocket(PyObject *self, PyObject *args) +multiprocessing_closesocket(PyObject *self, PyObject *args) { HANDLE handle; int ret; @@ -99,12 +64,12 @@ Billiard_closesocket(PyObject *self, PyObject *args) Py_END_ALLOW_THREADS if (ret) - return PyErr_SetExcFromWindowsErr(PyExc_IOError, WSAGetLastError()); + return PyErr_SetExcFromWindowsErr(PyExc_OSError, WSAGetLastError()); Py_RETURN_NONE; } static PyObject * -Billiard_recv(PyObject *self, PyObject *args) +multiprocessing_recv(PyObject *self, PyObject *args) { HANDLE handle; int size, nread; @@ -123,14 +88,14 @@ Billiard_recv(PyObject *self, PyObject *args) if (nread < 0) { Py_DECREF(buf); - return PyErr_SetExcFromWindowsErr(PyExc_IOError, WSAGetLastError()); + return PyErr_SetExcFromWindowsErr(PyExc_OSError, WSAGetLastError()); } _PyBytes_Resize(&buf, nread); return buf; } static PyObject * -Billiard_send(PyObject *self, PyObject *args) +multiprocessing_send(PyObject *self, PyObject *args) { HANDLE handle; Py_buffer buf; @@ -147,205 +112,24 @@ Billiard_send(PyObject *self, PyObject *args) PyBuffer_Release(&buf); if (ret < 0) - return PyErr_SetExcFromWindowsErr(PyExc_IOError, WSAGetLastError()); + return PyErr_SetExcFromWindowsErr(PyExc_OSError, WSAGetLastError()); return PyLong_FromLong(ret); } - -/* - * Unix only - */ - -#else /* !MS_WINDOWS */ - -#if HAVE_FD_TRANSFER - -/* Functions for transferring file descriptors between processes. - Reimplements some of the functionality of the fdcred - module at http://www.mca-ltd.com/resources/fdcred_1.tgz. 
*/ - -static PyObject * -Billiard_multiprocessing_sendfd(PyObject *self, PyObject *args) -{ - int conn, fd, res; - char dummy_char; - char buf[CMSG_SPACE(sizeof(int))]; - struct msghdr msg = {0}; - struct iovec dummy_iov; - struct cmsghdr *cmsg; - - if (!PyArg_ParseTuple(args, "ii", &conn, &fd)) - return NULL; - - dummy_iov.iov_base = &dummy_char; - dummy_iov.iov_len = 1; - msg.msg_control = buf; - msg.msg_controllen = sizeof(buf); - msg.msg_iov = &dummy_iov; - msg.msg_iovlen = 1; - cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - cmsg->cmsg_len = CMSG_LEN(sizeof(int)); - msg.msg_controllen = cmsg->cmsg_len; - *(int*)CMSG_DATA(cmsg) = fd; - - Py_BEGIN_ALLOW_THREADS - res = sendmsg(conn, &msg, 0); - Py_END_ALLOW_THREADS - - if (res < 0) - return PyErr_SetFromErrno(PyExc_OSError); - Py_RETURN_NONE; -} - -static PyObject * -Billiard_multiprocessing_recvfd(PyObject *self, PyObject *args) -{ - int conn, fd, res; - char dummy_char; - char buf[CMSG_SPACE(sizeof(int))]; - struct msghdr msg = {0}; - struct iovec dummy_iov; - struct cmsghdr *cmsg; - - if (!PyArg_ParseTuple(args, "i", &conn)) - return NULL; - - dummy_iov.iov_base = &dummy_char; - dummy_iov.iov_len = 1; - msg.msg_control = buf; - msg.msg_controllen = sizeof(buf); - msg.msg_iov = &dummy_iov; - msg.msg_iovlen = 1; - cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - cmsg->cmsg_len = CMSG_LEN(sizeof(int)); - msg.msg_controllen = cmsg->cmsg_len; - - Py_BEGIN_ALLOW_THREADS - res = recvmsg(conn, &msg, 0); - Py_END_ALLOW_THREADS - - if (res < 0) - return PyErr_SetFromErrno(PyExc_OSError); - - fd = *(int*)CMSG_DATA(cmsg); - return Py_BuildValue("i", fd); -} - -#endif /* HAVE_FD_TRANSFER */ - -#endif /* !MS_WINDOWS */ - - -/* - * All platforms - */ - -static PyObject* -Billiard_multiprocessing_address_of_buffer(PyObject *self, PyObject *obj) -{ - void *buffer; - Py_ssize_t buffer_len; - - if (PyObject_AsWriteBuffer(obj, &buffer, 
&buffer_len) < 0) - return NULL; - - return Py_BuildValue("N" F_PY_SSIZE_T, - PyLong_FromVoidPtr(buffer), buffer_len); -} - -#if !defined(MS_WINDOWS) - -static PyObject * -Billiard_read(PyObject *self, PyObject *args) -{ - int fd; - Py_buffer view; - Py_ssize_t buflen, recvlen = 0; - - char *buf = NULL; - - Py_ssize_t n = 0; - - if (!PyArg_ParseTuple(args, "iw*|n", &fd, &view, &recvlen)) - return NULL; - buflen = view.len; - buf = view.buf; - - if (recvlen < 0) { - PyBuffer_Release(&view); - PyErr_SetString(PyExc_ValueError, "negative len for read"); - return NULL; - } - - if (recvlen == 0) { - recvlen = buflen; - } - - if (buflen < recvlen) { - PyBuffer_Release(&view); - PyErr_SetString(PyExc_ValueError, - "Buffer too small for requested bytes"); - return NULL; - - } - - if (buflen < 0 || buflen == 0) { - errno = EINVAL; - goto bail; - } - // Requires Python 2.7 - //if (!_PyVerify_fd(fd)) goto bail; - - Py_BEGIN_ALLOW_THREADS - n = read(fd, buf, recvlen); - Py_END_ALLOW_THREADS - if (n < 0) goto bail; - PyBuffer_Release(&view); - return PyInt_FromSsize_t(n); - -bail: - PyBuffer_Release(&view); - return PyErr_SetFromErrno(PyExc_OSError); -} - -# endif /* !MS_WINDOWS */ - - +#endif /* * Function table */ -static PyMethodDef Billiard_module_methods[] = { - {"address_of_buffer", Billiard_multiprocessing_address_of_buffer, METH_O, - "address_of_buffer(obj) -> int\n\n" - "Return address of obj assuming obj supports buffer inteface"}, -#if HAVE_FD_TRANSFER - {"sendfd", Billiard_multiprocessing_sendfd, METH_VARARGS, - "sendfd(sockfd, fd) -> None\n\n" - "Send file descriptor given by fd over the unix domain socket\n" - "whose file descriptor is sockfd"}, - {"recvfd", Billiard_multiprocessing_recvfd, METH_VARARGS, - "recvfd(sockfd) -> fd\n\n" - "Receive a file descriptor over a unix domain socket\n" - "whose file descriptor is sockfd"}, -#endif -#if !defined(MS_WINDOWS) - {"read", Billiard_read, METH_VARARGS, - "read(fd, buffer) -> bytes\n\n" - "Read from file descriptor 
into buffer."}, -#endif +static PyMethodDef module_methods[] = { #ifdef MS_WINDOWS - {"closesocket", Billiard_closesocket, METH_VARARGS, ""}, - {"recv", Billiard_recv, METH_VARARGS, ""}, - {"send", Billiard_send, METH_VARARGS, ""}, + {"closesocket", multiprocessing_closesocket, METH_VARARGS, ""}, + {"recv", multiprocessing_recv, METH_VARARGS, ""}, + {"send", multiprocessing_send, METH_VARARGS, ""}, #endif -#ifndef POSIX_SEMAPHORES_NOT_ENABLED - {"sem_unlink", Billiard_semlock_unlink, METH_VARARGS, ""}, +#if !defined(POSIX_SEMAPHORES_NOT_ENABLED) && !defined(__ANDROID__) + {"sem_unlink", _PyMp_sem_unlink, METH_VARARGS, ""}, #endif {NULL} }; @@ -355,49 +139,64 @@ static PyMethodDef Billiard_module_methods[] = { * Initialize */ +static struct PyModuleDef multiprocessing_module = { + PyModuleDef_HEAD_INIT, + "_multiprocessing", + NULL, + -1, + module_methods, + NULL, + NULL, + NULL, + NULL +}; + + PyMODINIT_FUNC -init_billiard(void) +PyInit__multiprocessing(void) { - PyObject *module, *temp, *value; + PyObject *module, *temp, *value = NULL; /* Initialize module */ - module = Py_InitModule("_billiard", Billiard_module_methods); + module = PyModule_Create(&multiprocessing_module); if (!module) - return; + return NULL; #if defined(MS_WINDOWS) || \ (defined(HAVE_SEM_OPEN) && !defined(POSIX_SEMAPHORES_NOT_ENABLED)) - /* Add SemLock type to module */ - if (PyType_Ready(&BilliardSemLockType) < 0) - return; - Py_INCREF(&BilliardSemLockType); - PyDict_SetItemString(BilliardSemLockType.tp_dict, "SEM_VALUE_MAX", - Py_BuildValue("i", SEM_VALUE_MAX)); - PyModule_AddObject(module, "SemLock", (PyObject*)&BilliardSemLockType); -#endif - -#ifdef MS_WINDOWS - /* Initialize the event handle used to signal Ctrl-C */ - sigint_event = CreateEvent(NULL, TRUE, FALSE, NULL); - if (!sigint_event) { - PyErr_SetFromWindowsErr(0); - return; - } - if (!SetConsoleCtrlHandler(ProcessingCtrlHandler, TRUE)) { - PyErr_SetFromWindowsErr(0); - return; + /* Add _PyMp_SemLock type to module */ + if 
(PyType_Ready(&_PyMp_SemLockType) < 0) + return NULL; + Py_INCREF(&_PyMp_SemLockType); + { + PyObject *py_sem_value_max; + /* Some systems define SEM_VALUE_MAX as an unsigned value that + * causes it to be negative when used as an int (NetBSD). + * + * Issue #28152: Use (0) instead of 0 to fix a warning on dead code + * when using clang -Wunreachable-code. */ + if ((int)(SEM_VALUE_MAX) < (0)) + py_sem_value_max = PyLong_FromLong(INT_MAX); + else + py_sem_value_max = PyLong_FromLong(SEM_VALUE_MAX); + if (py_sem_value_max == NULL) + return NULL; + PyDict_SetItemString(_PyMp_SemLockType.tp_dict, "SEM_VALUE_MAX", + py_sem_value_max); } + PyModule_AddObject(module, "SemLock", (PyObject*)&_PyMp_SemLockType); #endif /* Add configuration macros */ temp = PyDict_New(); if (!temp) - return; + return NULL; + #define ADD_FLAG(name) \ value = Py_BuildValue("i", name); \ - if (value == NULL) { Py_DECREF(temp); return; } \ + if (value == NULL) { Py_DECREF(temp); return NULL; } \ if (PyDict_SetItemString(temp, #name, value) < 0) { \ - Py_DECREF(temp); Py_DECREF(value); return; } \ + Py_DECREF(temp); Py_DECREF(value); return NULL; } \ Py_DECREF(value) #if defined(HAVE_SEM_OPEN) && !defined(POSIX_SEMAPHORES_NOT_ENABLED) @@ -406,15 +205,15 @@ init_billiard(void) #ifdef HAVE_SEM_TIMEDWAIT ADD_FLAG(HAVE_SEM_TIMEDWAIT); #endif -#ifdef HAVE_FD_TRANSFER - ADD_FLAG(HAVE_FD_TRANSFER); -#endif #ifdef HAVE_BROKEN_SEM_GETVALUE ADD_FLAG(HAVE_BROKEN_SEM_GETVALUE); #endif #ifdef HAVE_BROKEN_SEM_UNLINK ADD_FLAG(HAVE_BROKEN_SEM_UNLINK); #endif + if (PyModule_AddObject(module, "flags", temp) < 0) - return; + return NULL; + + return module; } diff --git a/Modules/_billiard/multiprocessing.h b/Modules/_billiard/multiprocessing.h index 74d2f33d..fe78135d 100644 --- a/Modules/_billiard/multiprocessing.h +++ b/Modules/_billiard/multiprocessing.h @@ -3,12 +3,6 @@ #define PY_SSIZE_T_CLEAN -#ifdef __sun -/* The control message API is only available on Solaris - if XPG 4.2 or later is requested. 
*/ -#define _XOPEN_SOURCE 500 -#endif - #include "Python.h" #include "structmember.h" #include "pythread.h" @@ -29,22 +23,10 @@ # define SEM_VALUE_MAX LONG_MAX #else # include /* O_CREAT and O_EXCL */ -# include -# include -# include -# include /* htonl() and ntohl() */ # if defined(HAVE_SEM_OPEN) && !defined(POSIX_SEMAPHORES_NOT_ENABLED) # include typedef sem_t *SEM_HANDLE; # endif -# define HANDLE long -# define SOCKET int -# define BOOL int -# define UINT32 uint32_t -# define INT32 int32_t -# define TRUE 1 -# define FALSE 0 -# define INVALID_HANDLE_VALUE (-1) #endif /* @@ -63,20 +45,6 @@ #endif -/* - * Make sure Py_ssize_t available - */ - -#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN) - typedef int Py_ssize_t; -# define PY_SSIZE_T_MAX INT_MAX -# define PY_SSIZE_T_MIN INT_MIN -# define F_PY_SSIZE_T "i" -# define PyInt_FromSsize_t(n) PyInt_FromLong((long)n) -#else -# define F_PY_SSIZE_T "n" -#endif - /* * Format codes */ @@ -84,7 +52,7 @@ #if SIZEOF_VOID_P == SIZEOF_LONG # define F_POINTER "k" # define T_POINTER T_ULONG -#elif defined(HAVE_LONG_LONG) && (SIZEOF_VOID_P == SIZEOF_LONG_LONG) +#elif SIZEOF_VOID_P == SIZEOF_LONG_LONG # define F_POINTER "K" # define T_POINTER T_ULONGLONG #else @@ -96,8 +64,6 @@ # define T_HANDLE T_POINTER # define F_SEM_HANDLE F_HANDLE # define T_SEM_HANDLE T_HANDLE -# define F_DWORD "k" -# define T_DWORD T_ULONG #else # define F_HANDLE "i" # define T_HANDLE T_INT @@ -105,12 +71,6 @@ # define T_SEM_HANDLE T_POINTER #endif -#if PY_VERSION_HEX >= 0x03000000 -# define F_RBUFFER "y" -#else -# define F_RBUFFER "s" -#endif - /* * Error codes which can be returned by functions called without GIL */ @@ -118,67 +78,16 @@ #define MP_SUCCESS (0) #define MP_STANDARD_ERROR (-1) #define MP_MEMORY_ERROR (-1001) -#define MP_END_OF_FILE (-1002) -#define MP_EARLY_END_OF_FILE (-1003) -#define MP_BAD_MESSAGE_LENGTH (-1004) -#define MP_SOCKET_ERROR (-1005) -#define MP_EXCEPTION_HAS_BEEN_SET (-1006) +#define MP_SOCKET_ERROR (-1002) +#define 
MP_EXCEPTION_HAS_BEEN_SET (-1003) -PyObject *Billiard_SetError(PyObject *Type, int num); +PyObject *_PyMp_SetError(PyObject *Type, int num); /* * Externs - not all will really exist on all platforms */ -extern PyObject *Billiard_BufferTooShort; -extern PyTypeObject BilliardSemLockType; -extern PyObject *Billiard_semlock_unlink(PyObject *ignore, PyObject *args); -extern HANDLE sigint_event; - -/* - * Py3k compatibility - */ - -#if PY_VERSION_HEX >= 0x03000000 -# define PICKLE_MODULE "pickle" -# define FROM_FORMAT PyUnicode_FromFormat -# define PyInt_FromLong PyLong_FromLong -# define PyInt_FromSsize_t PyLong_FromSsize_t -#else -# define PICKLE_MODULE "cPickle" -# define FROM_FORMAT PyString_FromFormat -#endif - -#ifndef PyVarObject_HEAD_INIT -# define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size, -#endif - -#ifndef Py_TPFLAGS_HAVE_WEAKREFS -# define Py_TPFLAGS_HAVE_WEAKREFS 0 -#endif - -/* - * Connection definition - */ - -#define CONNECTION_BUFFER_SIZE 131072 - -typedef struct { - PyObject_HEAD - HANDLE handle; - int flags; - PyObject *weakreflist; - char buffer[CONNECTION_BUFFER_SIZE]; -} BilliardConnectionObject; - -/* - * Miscellaneous - */ - -#define MAX_MESSAGE_LENGTH 0x7fffffff - -#ifndef Py_MIN -# define Py_MIN(x, y) (((x) > (y)) ? 
(y) : (x)) -#endif +extern PyTypeObject _PyMp_SemLockType; +extern PyObject *_PyMp_sem_unlink(PyObject *ignore, PyObject *args); #endif /* MULTIPROCESSING_H */ diff --git a/Modules/_billiard/posixshmem.c b/Modules/_billiard/posixshmem.c new file mode 100644 index 00000000..436ac6d6 --- /dev/null +++ b/Modules/_billiard/posixshmem.c @@ -0,0 +1,130 @@ +/* +posixshmem - A Python extension that provides shm_open() and shm_unlink() +*/ + +#define PY_SSIZE_T_CLEAN + +#include + +// for shm_open() and shm_unlink() +#ifdef HAVE_SYS_MMAN_H +#include +#endif + +/*[clinic input] +module _posixshmem +[clinic start generated code]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=a416734e49164bf8]*/ + +/* + * + * Module-level functions & meta stuff + * + */ + +#ifdef HAVE_SHM_OPEN +/*[clinic input] +_posixshmem.shm_open -> int + path: unicode + flags: int + mode: int = 0o777 + +# "shm_open(path, flags, mode=0o777)\n\n\ + +Open a shared memory object. Returns a file descriptor (integer). + +[clinic start generated code]*/ + +static int +_posixshmem_shm_open_impl(PyObject *module, PyObject *path, int flags, + int mode) +/*[clinic end generated code: output=8d110171a4fa20df input=e83b58fa802fac25]*/ +{ + int fd; + int async_err = 0; + const char *name = PyUnicode_AsUTF8(path); + if (name == NULL) { + return -1; + } + do { + Py_BEGIN_ALLOW_THREADS + fd = shm_open(name, flags, mode); + Py_END_ALLOW_THREADS + } while (fd < 0 && errno == EINTR && !(async_err = PyErr_CheckSignals())); + + if (fd < 0) { + if (!async_err) + PyErr_SetFromErrnoWithFilenameObject(PyExc_OSError, path); + return -1; + } + + return fd; +} +#endif /* HAVE_SHM_OPEN */ + +#ifdef HAVE_SHM_UNLINK +/*[clinic input] +_posixshmem.shm_unlink + path: unicode + +Remove a shared memory object (similar to unlink()). + +Remove a shared memory object name, and, once all processes have unmapped +the object, de-allocates and destroys the contents of the associated memory +region. 
+ +[clinic start generated code]*/ + +static PyObject * +_posixshmem_shm_unlink_impl(PyObject *module, PyObject *path) +/*[clinic end generated code: output=42f8b23d134b9ff5 input=8dc0f87143e3b300]*/ +{ + int rv; + int async_err = 0; + const char *name = PyUnicode_AsUTF8(path); + if (name == NULL) { + return NULL; + } + do { + Py_BEGIN_ALLOW_THREADS + rv = shm_unlink(name); + Py_END_ALLOW_THREADS + } while (rv < 0 && errno == EINTR && !(async_err = PyErr_CheckSignals())); + + if (rv < 0) { + if (!async_err) + PyErr_SetFromErrnoWithFilenameObject(PyExc_OSError, path); + return NULL; + } + + Py_RETURN_NONE; +} +#endif /* HAVE_SHM_UNLINK */ + +#include "clinic/posixshmem.c.h" + +static PyMethodDef module_methods[ ] = { + _POSIXSHMEM_SHM_OPEN_METHODDEF + _POSIXSHMEM_SHM_UNLINK_METHODDEF + {NULL} /* Sentinel */ +}; + + +static struct PyModuleDef this_module = { + PyModuleDef_HEAD_INIT, // m_base + "_posixshmem", // m_name + "POSIX shared memory module", // m_doc + -1, // m_size (space allocated for module globals) + module_methods, // m_methods +}; + +/* Module init function */ +PyMODINIT_FUNC +PyInit__posixshmem(void) { + PyObject *module; + module = PyModule_Create(&this_module); + if (!module) { + return NULL; + } + return module; +} diff --git a/Modules/_billiard/semaphore.c b/Modules/_billiard/semaphore.c index e780561e..ee490256 100644 --- a/Modules/_billiard/semaphore.c +++ b/Modules/_billiard/semaphore.c @@ -14,12 +14,12 @@ enum { RECURSIVE_MUTEX, SEMAPHORE }; typedef struct { PyObject_HEAD SEM_HANDLE handle; - long last_tid; + unsigned long last_tid; int count; int maxvalue; int kind; char *name; -} BilliardSemLockObject; +} SemLockObject; #define ISMINE(o) (o->count > 0 && PyThread_get_thread_ident() == o->last_tid) @@ -36,11 +36,11 @@ typedef struct { #define SEM_GET_LAST_ERROR() GetLastError() #define SEM_CREATE(name, val, max) CreateSemaphore(NULL, val, max, NULL) #define SEM_CLOSE(sem) (CloseHandle(sem) ? 
0 : -1) -#define SEM_GETVALUE(sem, pval) _Billiard_GetSemaphoreValue(sem, pval) +#define SEM_GETVALUE(sem, pval) _GetSemaphoreValue(sem, pval) #define SEM_UNLINK(name) 0 static int -_Billiard_GetSemaphoreValue(HANDLE handle, long *value) +_GetSemaphoreValue(HANDLE handle, long *value) { long previous; @@ -59,12 +59,13 @@ _Billiard_GetSemaphoreValue(HANDLE handle, long *value) } static PyObject * -Billiard_semlock_acquire(BilliardSemLockObject *self, PyObject *args, PyObject *kwds) +semlock_acquire(SemLockObject *self, PyObject *args, PyObject *kwds) { int blocking = 1; double timeout; PyObject *timeout_obj = Py_None; - DWORD res, full_msecs, msecs, start, ticks; + DWORD res, full_msecs, nhandles; + HANDLE handles[2], sigint_event; static char *kwlist[] = {"block", "timeout", NULL}; @@ -98,65 +99,55 @@ Billiard_semlock_acquire(BilliardSemLockObject *self, PyObject *args, PyObject * Py_RETURN_TRUE; } - /* check whether we can acquire without blocking */ + /* check whether we can acquire without releasing the GIL and blocking */ if (WaitForSingleObjectEx(self->handle, 0, FALSE) == WAIT_OBJECT_0) { self->last_tid = GetCurrentThreadId(); ++self->count; Py_RETURN_TRUE; } - msecs = full_msecs; - start = GetTickCount(); - - for ( ; ; ) { - HANDLE handles[2] = {self->handle, sigint_event}; + /* prepare list of handles */ + nhandles = 0; + handles[nhandles++] = self->handle; + if (_PyOS_IsMainThread()) { + sigint_event = _PyOS_SigintEvent(); + assert(sigint_event != NULL); + handles[nhandles++] = sigint_event; + } + else { + sigint_event = NULL; + } - /* do the wait */ - Py_BEGIN_ALLOW_THREADS + /* do the wait */ + Py_BEGIN_ALLOW_THREADS + if (sigint_event != NULL) ResetEvent(sigint_event); - res = WaitForMultipleObjectsEx(2, handles, FALSE, msecs, FALSE); - Py_END_ALLOW_THREADS - - /* handle result */ - if (res != WAIT_OBJECT_0 + 1) - break; - - /* got SIGINT so give signal handler a chance to run */ - Sleep(1); - - /* if this is main thread let KeyboardInterrupt be raised 
*/ - if (PyErr_CheckSignals()) - return NULL; - - /* recalculate timeout */ - if (msecs != INFINITE) { - ticks = GetTickCount(); - if ((DWORD)(ticks - start) >= full_msecs) - Py_RETURN_FALSE; - msecs = full_msecs - (ticks - start); - } - } + res = WaitForMultipleObjectsEx(nhandles, handles, FALSE, full_msecs, FALSE); + Py_END_ALLOW_THREADS /* handle result */ switch (res) { case WAIT_TIMEOUT: Py_RETURN_FALSE; - case WAIT_OBJECT_0: + case WAIT_OBJECT_0 + 0: self->last_tid = GetCurrentThreadId(); ++self->count; Py_RETURN_TRUE; + case WAIT_OBJECT_0 + 1: + errno = EINTR; + return PyErr_SetFromErrno(PyExc_OSError); case WAIT_FAILED: return PyErr_SetFromWindowsErr(0); default: - PyErr_Format(PyExc_RuntimeError, "WaitForSingleObjectEx() or " + PyErr_Format(PyExc_RuntimeError, "WaitForSingleObject() or " "WaitForMultipleObjects() gave unrecognized " - "value %d", res); + "value %u", res); return NULL; } } static PyObject * -Billiard_semlock_release(BilliardSemLockObject *self, PyObject *args) +semlock_release(SemLockObject *self, PyObject *args) { if (self->kind == RECURSIVE_MUTEX) { if (!ISMINE(self)) { @@ -199,22 +190,22 @@ Billiard_semlock_release(BilliardSemLockObject *self, PyObject *args) #define SEM_GETVALUE(sem, pval) sem_getvalue(sem, pval) #define SEM_UNLINK(name) sem_unlink(name) -/* macOS 10.4 defines SEM_FAILED as -1 instead (sem_t *)-1; this gives - compiler warnings, and (potentially) undefined behavior. */ +/* OS X 10.4 defines SEM_FAILED as -1 instead of (sem_t *)-1; this gives + compiler warnings, and (potentially) undefined behaviour. 
*/ #ifdef __APPLE__ -# undef SEM_FAILED -# define SEM_FAILED ((sem_t *)-1) +# undef SEM_FAILED +# define SEM_FAILED ((sem_t *)-1) #endif #ifndef HAVE_SEM_UNLINK # define sem_unlink(name) 0 #endif -//#ifndef HAVE_SEM_TIMEDWAIT -# define sem_timedwait(sem,deadline) Billiard_sem_timedwait_save(sem,deadline,_save) +#ifndef HAVE_SEM_TIMEDWAIT +# define sem_timedwait(sem,deadline) sem_timedwait_save(sem,deadline,_save) -int -Billiard_sem_timedwait_save(sem_t *sem, struct timespec *deadline, PyThreadState *_save) +static int +sem_timedwait_save(sem_t *sem, struct timespec *deadline, PyThreadState *_save) { int res; unsigned long delay, difference; @@ -271,17 +262,14 @@ Billiard_sem_timedwait_save(sem_t *sem, struct timespec *deadline, PyThreadState } } -//#endif /* !HAVE_SEM_TIMEDWAIT */ +#endif /* !HAVE_SEM_TIMEDWAIT */ static PyObject * -Billiard_semlock_acquire(BilliardSemLockObject *self, PyObject *args, PyObject *kwds) +semlock_acquire(SemLockObject *self, PyObject *args, PyObject *kwds) { - int blocking = 1, res; - double timeout; + int blocking = 1, res, err = 0; PyObject *timeout_obj = Py_None; struct timespec deadline = {0}; - struct timeval now; - long sec, nsec; static char *kwlist[] = {"block", "timeout", NULL}; @@ -294,39 +282,55 @@ Billiard_semlock_acquire(BilliardSemLockObject *self, PyObject *args, PyObject * Py_RETURN_TRUE; } - if (timeout_obj != Py_None) { - timeout = PyFloat_AsDouble(timeout_obj); - if (PyErr_Occurred()) + int use_deadline = (timeout_obj != Py_None); + if (use_deadline) { + double timeout = PyFloat_AsDouble(timeout_obj); + if (PyErr_Occurred()) { return NULL; - if (timeout < 0.0) + } + if (timeout < 0.0) { timeout = 0.0; + } + struct timeval now; if (gettimeofday(&now, NULL) < 0) { PyErr_SetFromErrno(PyExc_OSError); return NULL; } - sec = (long) timeout; - nsec = (long) (1e9 * (timeout - sec) + 0.5); + long sec = (long) timeout; + long nsec = (long) (1e9 * (timeout - sec) + 0.5); deadline.tv_sec = now.tv_sec + sec; deadline.tv_nsec = 
now.tv_usec * 1000 + nsec; deadline.tv_sec += (deadline.tv_nsec / 1000000000); deadline.tv_nsec %= 1000000000; } + /* Check whether we can acquire without releasing the GIL and blocking */ do { - Py_BEGIN_ALLOW_THREADS - if (blocking && timeout_obj == Py_None) - res = sem_wait(self->handle); - else if (!blocking) - res = sem_trywait(self->handle); - else - res = sem_timedwait(self->handle, &deadline); - Py_END_ALLOW_THREADS - if (res == MP_EXCEPTION_HAS_BEEN_SET) - break; + res = sem_trywait(self->handle); + err = errno; } while (res < 0 && errno == EINTR && !PyErr_CheckSignals()); + errno = err; + + if (res < 0 && errno == EAGAIN && blocking) { + /* Couldn't acquire immediately, need to block */ + do { + Py_BEGIN_ALLOW_THREADS + if (!use_deadline) { + res = sem_wait(self->handle); + } + else { + res = sem_timedwait(self->handle, &deadline); + } + Py_END_ALLOW_THREADS + err = errno; + if (res == MP_EXCEPTION_HAS_BEEN_SET) + break; + } while (res < 0 && errno == EINTR && !PyErr_CheckSignals()); + } if (res < 0) { + errno = err; if (errno == EAGAIN || errno == ETIMEDOUT) Py_RETURN_FALSE; else if (errno == EINTR) @@ -342,7 +346,7 @@ Billiard_semlock_acquire(BilliardSemLockObject *self, PyObject *args, PyObject * } static PyObject * -Billiard_semlock_release(BilliardSemLockObject *self, PyObject *args) +semlock_release(SemLockObject *self, PyObject *args) { if (self->kind == RECURSIVE_MUTEX) { if (!ISMINE(self)) { @@ -408,12 +412,12 @@ Billiard_semlock_release(BilliardSemLockObject *self, PyObject *args) */ static PyObject * -Billiard_newsemlockobject(PyTypeObject *type, SEM_HANDLE handle, int kind, int maxvalue, +newsemlockobject(PyTypeObject *type, SEM_HANDLE handle, int kind, int maxvalue, char *name) { - BilliardSemLockObject *self; + SemLockObject *self; - self = PyObject_New(BilliardSemLockObject, type); + self = PyObject_New(SemLockObject, type); if (!self) return NULL; self->handle = handle; @@ -426,7 +430,7 @@ Billiard_newsemlockobject(PyTypeObject *type, 
SEM_HANDLE handle, int kind, int m } static PyObject * -Billiard_semlock_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +semlock_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { SEM_HANDLE handle = SEM_FAILED; int kind, maxvalue, value, unlink; @@ -446,8 +450,9 @@ Billiard_semlock_new(PyTypeObject *type, PyObject *args, PyObject *kwds) if (!unlink) { name_copy = PyMem_Malloc(strlen(name) + 1); - if (name_copy == NULL) - goto failure; + if (name_copy == NULL) { + return PyErr_NoMemory(); + } strcpy(name_copy, name); } @@ -460,7 +465,7 @@ Billiard_semlock_new(PyTypeObject *type, PyObject *args, PyObject *kwds) if (unlink && SEM_UNLINK(name) < 0) goto failure; - result = Billiard_newsemlockobject(type, handle, kind, maxvalue, name_copy); + result = newsemlockobject(type, handle, kind, maxvalue, name_copy); if (!result) goto failure; @@ -470,12 +475,14 @@ Billiard_semlock_new(PyTypeObject *type, PyObject *args, PyObject *kwds) if (handle != SEM_FAILED) SEM_CLOSE(handle); PyMem_Free(name_copy); - Billiard_SetError(NULL, MP_STANDARD_ERROR); + if (!PyErr_Occurred()) { + _PyMp_SetError(NULL, MP_STANDARD_ERROR); + } return NULL; } static PyObject * -Billiard_semlock_rebuild(PyTypeObject *type, PyObject *args) +semlock_rebuild(PyTypeObject *type, PyObject *args) { SEM_HANDLE handle; int kind, maxvalue; @@ -502,11 +509,11 @@ Billiard_semlock_rebuild(PyTypeObject *type, PyObject *args) } #endif - return Billiard_newsemlockobject(type, handle, kind, maxvalue, name_copy); + return newsemlockobject(type, handle, kind, maxvalue, name_copy); } static void -Billiard_semlock_dealloc(BilliardSemLockObject* self) +semlock_dealloc(SemLockObject* self) { if (self->handle != SEM_FAILED) SEM_CLOSE(self->handle); @@ -515,20 +522,20 @@ Billiard_semlock_dealloc(BilliardSemLockObject* self) } static PyObject * -Billiard_semlock_count(BilliardSemLockObject *self) +semlock_count(SemLockObject *self, PyObject *Py_UNUSED(ignored)) { - return PyInt_FromLong((long)self->count); + 
return PyLong_FromLong((long)self->count); } static PyObject * -Billiard_semlock_ismine(BilliardSemLockObject *self) +semlock_ismine(SemLockObject *self, PyObject *Py_UNUSED(ignored)) { /* only makes sense for a lock */ return PyBool_FromLong(ISMINE(self)); } static PyObject * -Billiard_semlock_getvalue(BilliardSemLockObject *self) +semlock_getvalue(SemLockObject *self, PyObject *Py_UNUSED(ignored)) { #ifdef HAVE_BROKEN_SEM_GETVALUE PyErr_SetNone(PyExc_NotImplementedError); @@ -536,87 +543,68 @@ Billiard_semlock_getvalue(BilliardSemLockObject *self) #else int sval; if (SEM_GETVALUE(self->handle, &sval) < 0) - return Billiard_SetError(NULL, MP_STANDARD_ERROR); + return _PyMp_SetError(NULL, MP_STANDARD_ERROR); /* some posix implementations use negative numbers to indicate the number of waiting threads */ if (sval < 0) sval = 0; - return PyInt_FromLong((long)sval); + return PyLong_FromLong((long)sval); #endif } static PyObject * -Billiard_semlock_iszero(BilliardSemLockObject *self) +semlock_iszero(SemLockObject *self, PyObject *Py_UNUSED(ignored)) { #ifdef HAVE_BROKEN_SEM_GETVALUE if (sem_trywait(self->handle) < 0) { if (errno == EAGAIN) Py_RETURN_TRUE; - return Billiard_SetError(NULL, MP_STANDARD_ERROR); + return _PyMp_SetError(NULL, MP_STANDARD_ERROR); } else { if (sem_post(self->handle) < 0) - return Billiard_SetError(NULL, MP_STANDARD_ERROR); + return _PyMp_SetError(NULL, MP_STANDARD_ERROR); Py_RETURN_FALSE; } #else int sval; if (SEM_GETVALUE(self->handle, &sval) < 0) - return Billiard_SetError(NULL, MP_STANDARD_ERROR); + return _PyMp_SetError(NULL, MP_STANDARD_ERROR); return PyBool_FromLong((long)sval == 0); #endif } static PyObject * -Billiard_semlock_afterfork(BilliardSemLockObject *self) +semlock_afterfork(SemLockObject *self, PyObject *Py_UNUSED(ignored)) { self->count = 0; Py_RETURN_NONE; } -PyObject * -Billiard_semlock_unlink(PyObject *ignore, PyObject *args) -{ - char *name; - - if (!PyArg_ParseTuple(args, "s", &name)) - return NULL; - - if 
(SEM_UNLINK(name) < 0) { - Billiard_SetError(NULL, MP_STANDARD_ERROR); - return NULL; - } - - Py_RETURN_NONE; -} - /* * Semaphore methods */ -static PyMethodDef Billiard_semlock_methods[] = { - {"acquire", (PyCFunction)Billiard_semlock_acquire, METH_VARARGS | METH_KEYWORDS, +static PyMethodDef semlock_methods[] = { + {"acquire", (PyCFunction)(void(*)(void))semlock_acquire, METH_VARARGS | METH_KEYWORDS, "acquire the semaphore/lock"}, - {"release", (PyCFunction)Billiard_semlock_release, METH_NOARGS, + {"release", (PyCFunction)semlock_release, METH_NOARGS, "release the semaphore/lock"}, - {"__enter__", (PyCFunction)Billiard_semlock_acquire, METH_VARARGS | METH_KEYWORDS, + {"__enter__", (PyCFunction)(void(*)(void))semlock_acquire, METH_VARARGS | METH_KEYWORDS, "enter the semaphore/lock"}, - {"__exit__", (PyCFunction)Billiard_semlock_release, METH_VARARGS, + {"__exit__", (PyCFunction)semlock_release, METH_VARARGS, "exit the semaphore/lock"}, - {"_count", (PyCFunction)Billiard_semlock_count, METH_NOARGS, + {"_count", (PyCFunction)semlock_count, METH_NOARGS, "num of `acquire()`s minus num of `release()`s for this process"}, - {"_is_mine", (PyCFunction)Billiard_semlock_ismine, METH_NOARGS, + {"_is_mine", (PyCFunction)semlock_ismine, METH_NOARGS, "whether the lock is owned by this thread"}, - {"_get_value", (PyCFunction)Billiard_semlock_getvalue, METH_NOARGS, + {"_get_value", (PyCFunction)semlock_getvalue, METH_NOARGS, "get the value of the semaphore"}, - {"_is_zero", (PyCFunction)Billiard_semlock_iszero, METH_NOARGS, + {"_is_zero", (PyCFunction)semlock_iszero, METH_NOARGS, "returns whether semaphore has value zero"}, - {"_rebuild", (PyCFunction)Billiard_semlock_rebuild, METH_VARARGS | METH_CLASS, + {"_rebuild", (PyCFunction)semlock_rebuild, METH_VARARGS | METH_CLASS, ""}, - {"_after_fork", (PyCFunction)Billiard_semlock_afterfork, METH_NOARGS, + {"_after_fork", (PyCFunction)semlock_afterfork, METH_NOARGS, "rezero the net acquisition count after fork()"}, - {"sem_unlink", 
(PyCFunction)Billiard_semlock_unlink, METH_VARARGS | METH_STATIC, - "unlink the named semaphore using sem_unlink()"}, - {NULL} }; @@ -624,14 +612,14 @@ static PyMethodDef Billiard_semlock_methods[] = { * Member table */ -static PyMemberDef Billiard_semlock_members[] = { - {"handle", T_SEM_HANDLE, offsetof(BilliardSemLockObject, handle), READONLY, +static PyMemberDef semlock_members[] = { + {"handle", T_SEM_HANDLE, offsetof(SemLockObject, handle), READONLY, ""}, - {"kind", T_INT, offsetof(BilliardSemLockObject, kind), READONLY, + {"kind", T_INT, offsetof(SemLockObject, kind), READONLY, ""}, - {"maxvalue", T_INT, offsetof(BilliardSemLockObject, maxvalue), READONLY, + {"maxvalue", T_INT, offsetof(SemLockObject, maxvalue), READONLY, ""}, - {"name", T_STRING, offsetof(BilliardSemLockObject, name), READONLY, + {"name", T_STRING, offsetof(SemLockObject, name), READONLY, ""}, {NULL} }; @@ -640,16 +628,16 @@ static PyMemberDef Billiard_semlock_members[] = { * Semaphore type */ -PyTypeObject BilliardSemLockType = { +PyTypeObject _PyMp_SemLockType = { PyVarObject_HEAD_INIT(NULL, 0) - /* tp_name */ "_billiard.SemLock", - /* tp_basicsize */ sizeof(BilliardSemLockObject), + /* tp_name */ "_multiprocessing.SemLock", + /* tp_basicsize */ sizeof(SemLockObject), /* tp_itemsize */ 0, - /* tp_dealloc */ (destructor)Billiard_semlock_dealloc, - /* tp_print */ 0, + /* tp_dealloc */ (destructor)semlock_dealloc, + /* tp_vectorcall_offset */ 0, /* tp_getattr */ 0, /* tp_setattr */ 0, - /* tp_compare */ 0, + /* tp_as_async */ 0, /* tp_repr */ 0, /* tp_as_number */ 0, /* tp_as_sequence */ 0, @@ -668,8 +656,8 @@ PyTypeObject BilliardSemLockType = { /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ 0, - /* tp_methods */ Billiard_semlock_methods, - /* tp_members */ Billiard_semlock_members, + /* tp_methods */ semlock_methods, + /* tp_members */ semlock_members, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ 0, @@ -678,5 +666,25 @@ PyTypeObject BilliardSemLockType = { /* 
tp_dictoffset */ 0, /* tp_init */ 0, /* tp_alloc */ 0, - /* tp_new */ Billiard_semlock_new, + /* tp_new */ semlock_new, }; + +/* + * Function to unlink semaphore names + */ + +PyObject * +_PyMp_sem_unlink(PyObject *ignore, PyObject *args) +{ + char *name; + + if (!PyArg_ParseTuple(args, "s", &name)) + return NULL; + + if (SEM_UNLINK(name) < 0) { + _PyMp_SetError(NULL, MP_STANDARD_ERROR); + return NULL; + } + + Py_RETURN_NONE; +} diff --git a/billiard/__init__.py b/billiard/__init__.py index 642f823c..8336f381 100644 --- a/billiard/__init__.py +++ b/billiard/__init__.py @@ -1,4 +1,3 @@ -"""Python multiprocessing fork with improvements and bugfixes""" # # Package analogous to 'threading.py' but using processes # @@ -9,37 +8,19 @@ # subpackage 'multiprocessing.dummy' has the same API but is a simple # wrapper for 'threading'. # -# Try calling `multiprocessing.doc.main()` to read the html -# documentation in a webbrowser. -# -# # Copyright (c) 2006-2008, R Oudkerk # Licensed to PSF under a Contributor Agreement. # -from __future__ import absolute_import - import sys from . import context -VERSION = (3, 6, 1, 0) -__version__ = '.'.join(map(str, VERSION[0:4])) + "".join(VERSION[4:]) -__author__ = 'R Oudkerk / Python Software Foundation' -__author_email__ = 'python-dev@python.org' -__maintainer__ = 'Ask Solem' -__contact__ = "ask@celeryproject.org" -__homepage__ = "https://github.com/celery/billiard" -__docformat__ = "restructuredtext" - -# -eof meta- - # # Copy stuff from default context # -globals().update((name, getattr(context._default_context, name)) - for name in context._default_context.__all__) -__all__ = context._default_context.__all__ +__all__ = [x for x in dir(context._default_context) if not x.startswith('_')] +globals().update((name, getattr(context._default_context, name)) for name in __all__) # # XXX These should not really be documented or public. 
@@ -54,8 +35,3 @@ if '__main__' in sys.modules: sys.modules['__mp_main__'] = sys.modules['__main__'] - - -def ensure_multiprocessing(): - from ._ext import ensure_multiprocessing - return ensure_multiprocessing() diff --git a/billiard/connection.py b/billiard/connection.py index 627ed9b6..510e4b5a 100644 --- a/billiard/connection.py +++ b/billiard/connection.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # A higher level module for using sockets (or Windows named pipes) # @@ -8,44 +7,32 @@ # Licensed to PSF under a Contributor Agreement. # -from __future__ import absolute_import +__all__ = [ 'Client', 'Listener', 'Pipe', 'wait' ] -import errno import io import os import sys import socket -import select import struct +import time import tempfile import itertools -from . import reduction +import _multiprocessing + from . import util from . import AuthenticationError, BufferTooShort -from ._ext import _billiard -from .compat import setblocking, send_offset -from .five import monotonic -from .reduction import ForkingPickler +from .context import reduction +_ForkingPickler = reduction.ForkingPickler try: - from .compat import _winapi + import _winapi + from _winapi import WAIT_OBJECT_0, WAIT_ABANDONED_0, WAIT_TIMEOUT, INFINITE except ImportError: if sys.platform == 'win32': raise _winapi = None -else: - if sys.platform == 'win32': - WAIT_OBJECT_0 = _winapi.WAIT_OBJECT_0 - WAIT_ABANDONED_0 = _winapi.WAIT_ABANDONED_0 - - WAIT_TIMEOUT = _winapi.WAIT_TIMEOUT - INFINITE = _winapi.INFINITE - -__all__ = ['Client', 'Listener', 'Pipe', 'wait'] - -is_pypy = hasattr(sys, 'pypy_version_info') # # @@ -70,17 +57,15 @@ def _init_timeout(timeout=CONNECTION_TIMEOUT): - return monotonic() + timeout - + return time.monotonic() + timeout def _check_timeout(t): - return monotonic() > t + return time.monotonic() > t # # # - def arbitrary_address(family): ''' Return an arbitrary free address for the given family @@ -88,6 +73,11 @@ def arbitrary_address(family): if family == 'AF_INET': return 
('localhost', 0) elif family == 'AF_UNIX': + # Prefer abstract sockets if possible to avoid problems with the address + # size. When coding portable applications, some implementations have + # sun_path as short as 92 bytes in the sockaddr_un struct. + if util.abstract_sockets_supported: + return f"\0listener-{os.getpid()}-{next(_mmap_counter)}" return tempfile.mktemp(prefix='listener-', dir=util.get_temp_dir()) elif family == 'AF_PIPE': return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' % @@ -95,7 +85,6 @@ def arbitrary_address(family): else: raise ValueError('unrecognized family') - def _validate_family(family): ''' Checks if the family is valid for the current environment. @@ -108,7 +97,6 @@ def _validate_family(family): if not hasattr(socket, family): raise ValueError('Family %s is not recognized.' % family) - def address_type(address): ''' Return the types of the address @@ -119,7 +107,7 @@ def address_type(address): return 'AF_INET' elif type(address) is str and address.startswith('\\\\'): return 'AF_PIPE' - elif type(address) is str: + elif type(address) is str or util.is_abstract_socket_namespace(address): return 'AF_UNIX' else: raise ValueError('address type of %r unrecognized' % address) @@ -128,20 +116,10 @@ def address_type(address): # Connection classes # - -class _SocketContainer(object): - - def __init__(self, sock): - self.sock = sock - - -class _ConnectionBase(object): +class _ConnectionBase: _handle = None def __init__(self, handle, readable=True, writable=True): - if isinstance(handle, _SocketContainer): - self._socket = handle.sock # keep ref so not collected - handle = handle.sock.fileno() handle = handle.__index__() if handle < 0: raise ValueError("invalid handle") @@ -230,7 +208,7 @@ def send(self, obj): """Send a (picklable) object""" self._check_closed() self._check_writable() - self._send_bytes(ForkingPickler.dumps(obj)) + self._send_bytes(_ForkingPickler.dumps(obj)) def recv_bytes(self, maxlength=None): """ @@ -266,9 +244,8 @@ def 
recv_bytes_into(self, buf, offset=0): raise BufferTooShort(result.getvalue()) # Message can fit in dest result.seek(0) - result.readinto(m[ - offset // itemsize:(offset + size) // itemsize - ]) + result.readinto(m[offset // itemsize : + (offset + size) // itemsize]) return size def recv(self): @@ -276,7 +253,7 @@ def recv(self): self._check_closed() self._check_readable() buf = self._recv_bytes() - return ForkingPickler.loadbuf(buf) + return _ForkingPickler.loads(buf.getbuffer()) def poll(self, timeout=0.0): """Whether there is any input available to be read""" @@ -290,12 +267,6 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_tb): self.close() - def send_offset(self, buf, offset): - return send_offset(self.fileno(), buf, offset) - - def setblocking(self, blocking): - setblocking(self.fileno(), blocking) - if _winapi: @@ -315,7 +286,6 @@ def _send_bytes(self, buf): try: if err == _winapi.ERROR_IO_PENDING: waitres = _winapi.WaitForMultipleObjects( - [ov.event], False, INFINITE) assert waitres == WAIT_OBJECT_0 except: @@ -333,9 +303,8 @@ def _recv_bytes(self, maxsize=None): else: bsize = 128 if maxsize is None else min(maxsize, 128) try: - ov, err = _winapi.ReadFile( - self._handle, bsize, overlapped=True, - ) + ov, err = _winapi.ReadFile(self._handle, bsize, + overlapped=True) try: if err == _winapi.ERROR_IO_PENDING: waitres = _winapi.WaitForMultipleObjects( @@ -357,12 +326,11 @@ def _recv_bytes(self, maxsize=None): raise EOFError else: raise - raise RuntimeError( - "shouldn't get here; expected KeyboardInterrupt") + raise RuntimeError("shouldn't get here; expected KeyboardInterrupt") def _poll(self, timeout): if (self._got_empty_message or - _winapi.PeekNamedPipe(self._handle)[0] != 0): + _winapi.PeekNamedPipe(self._handle)[0] != 0): return True return bool(wait([self], timeout)) @@ -389,10 +357,10 @@ class Connection(_ConnectionBase): """ if _winapi: - def _close(self, _close=_billiard.closesocket): + def _close(self, 
_close=_multiprocessing.closesocket): _close(self._handle) - _write = _billiard.send - _read = _billiard.recv + _write = _multiprocessing.send + _read = _multiprocessing.recv else: def _close(self, _close=os.close): _close(self._handle) @@ -402,59 +370,57 @@ def _close(self, _close=os.close): def _send(self, buf, write=_write): remaining = len(buf) while True: - try: - n = write(self._handle, buf) - except (OSError, IOError, socket.error) as exc: - if getattr(exc, 'errno', None) != errno.EINTR: - raise - else: - remaining -= n - if remaining == 0: - break - buf = buf[n:] + n = write(self._handle, buf) + remaining -= n + if remaining == 0: + break + buf = buf[n:] def _recv(self, size, read=_read): buf = io.BytesIO() handle = self._handle remaining = size while remaining > 0: - try: - chunk = read(handle, remaining) - except (OSError, IOError, socket.error) as exc: - if getattr(exc, 'errno', None) != errno.EINTR: - raise - else: - n = len(chunk) - if n == 0: - if remaining == size: - raise EOFError - else: - raise OSError("got end of file during message") - buf.write(chunk) - remaining -= n + chunk = read(handle, remaining) + n = len(chunk) + if n == 0: + if remaining == size: + raise EOFError + else: + raise OSError("got end of file during message") + buf.write(chunk) + remaining -= n return buf - def _send_bytes(self, buf, memoryview=memoryview): + def _send_bytes(self, buf): n = len(buf) - # For wire compatibility with 3.2 and lower - header = struct.pack("!i", n) - if n > 16384: - # The payload is large so Nagle's algorithm won't be triggered - # and we'd better avoid the cost of concatenation. + if n > 0x7fffffff: + pre_header = struct.pack("!i", -1) + header = struct.pack("!Q", n) + self._send(pre_header) self._send(header) self._send(buf) else: - # Issue #20540: concatenate before sending, to avoid delays due - # to Nagle's algorithm on a TCP socket. 
- # Also note we want to avoid sending a 0-length buffer separately, - # to avoid "broken pipe" errors if the other end closed the pipe. - if isinstance(buf, memoryview): - buf = buf.tobytes() - self._send(header + buf) + # For wire compatibility with 3.7 and lower + header = struct.pack("!i", n) + if n > 16384: + # The payload is large so Nagle's algorithm won't be triggered + # and we'd better avoid the cost of concatenation. + self._send(header) + self._send(buf) + else: + # Issue #20540: concatenate before sending, to avoid delays due + # to Nagle's algorithm on a TCP socket. + # Also note we want to avoid sending a 0-length buffer separately, + # to avoid "broken pipe" errors if the other end closed the pipe. + self._send(header + buf) def _recv_bytes(self, maxsize=None): buf = self._recv(4) size, = struct.unpack("!i", buf.getvalue()) + if size == -1: + buf = self._recv(8) + size, = struct.unpack("!Q", buf.getvalue()) if maxsize is not None and size > maxsize: return None return self._recv(size) @@ -476,8 +442,8 @@ class Listener(object): connections, or for a Windows named pipe. ''' def __init__(self, address=None, family=None, backlog=1, authkey=None): - family = (family or - (address and address_type(address)) or default_family) + family = family or (address and address_type(address)) \ + or default_family address = address or arbitrary_address(family) _validate_family(family) @@ -514,8 +480,13 @@ def close(self): self._listener = None listener.close() - address = property(lambda self: self._listener._address) - last_accepted = property(lambda self: self._listener._last_accepted) + @property + def address(self): + return self._listener._address + + @property + def last_accepted(self): + return self._listener._last_accepted def __enter__(self): return self @@ -545,32 +516,20 @@ def Client(address, family=None, authkey=None): return c -def detach(sock): - if hasattr(sock, 'detach'): - return sock.detach() - # older socket lib does not have detach. 
We'll keep a reference around - # so that it does not get garbage collected. - return _SocketContainer(sock) - - if sys.platform != 'win32': - def Pipe(duplex=True, rnonblock=False, wnonblock=False): + def Pipe(duplex=True): ''' Returns pair of connection objects at either end of a pipe ''' if duplex: s1, s2 = socket.socketpair() - s1.setblocking(not rnonblock) - s2.setblocking(not wnonblock) - c1 = Connection(detach(s1)) - c2 = Connection(detach(s2)) + s1.setblocking(True) + s2.setblocking(True) + c1 = Connection(s1.detach()) + c2 = Connection(s2.detach()) else: fd1, fd2 = os.pipe() - if rnonblock: - setblocking(fd1, 0) - if wnonblock: - setblocking(fd2, 0) c1 = Connection(fd1, writable=False) c2 = Connection(fd2, readable=False) @@ -578,12 +537,10 @@ def Pipe(duplex=True, rnonblock=False, wnonblock=False): else: - def Pipe(duplex=True, rnonblock=False, wnonblock=False): + def Pipe(duplex=True): ''' Returns pair of connection objects at either end of a pipe ''' - assert not rnonblock, 'rnonblock not supported on windows' - assert not wnonblock, 'wnonblock not supported on windows' address = arbitrary_address('AF_PIPE') if duplex: openmode = _winapi.PIPE_ACCESS_DUPLEX @@ -602,14 +559,14 @@ def Pipe(duplex=True, rnonblock=False, wnonblock=False): 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, # default security descriptor: the handle cannot be inherited _winapi.NULL - ) + ) h2 = _winapi.CreateFile( address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, _winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL - ) + ) _winapi.SetNamedPipeHandleState( h2, _winapi.PIPE_READMODE_MESSAGE, None, None - ) + ) overlapped = _winapi.ConnectNamedPipe(h1, overlapped=True) _, err = overlapped.GetOverlappedResult(True) @@ -624,7 +581,6 @@ def Pipe(duplex=True, rnonblock=False, wnonblock=False): # Definitions for connections based on sockets # - class SocketListener(object): ''' Representation of a socket which is bound to an address and listening @@ -646,24 +602,18 @@ def __init__(self, 
address, family, backlog=1): self._family = family self._last_accepted = None - if family == 'AF_UNIX': + if family == 'AF_UNIX' and not util.is_abstract_socket_namespace(address): + # Linux abstract socket namespaces do not need to be explicitly unlinked self._unlink = util.Finalize( self, os.unlink, args=(address,), exitpriority=0 - ) + ) else: self._unlink = None def accept(self): - while True: - try: - s, self._last_accepted = self._socket.accept() - except (OSError, IOError, socket.error) as exc: - if getattr(exc, 'errno', None) != errno.EINTR: - raise - else: - break + s, self._last_accepted = self._socket.accept() s.setblocking(True) - return Connection(detach(s)) + return Connection(s.detach()) def close(self): try: @@ -680,10 +630,10 @@ def SocketClient(address): Return a connection object connected to the socket given by `address` ''' family = address_type(address) - s = socket.socket(getattr(socket, family)) - s.setblocking(True) - s.connect(address) - return Connection(detach(s)) + with socket.socket( getattr(socket, family) ) as s: + s.setblocking(True) + s.connect(address) + return Connection(s.detach()) # # Definitions for connections based on named pipes @@ -704,7 +654,7 @@ def __init__(self, address, backlog=None): self.close = util.Finalize( self, PipeListener._finalize_pipe_listener, args=(self._handle_queue, self._address), exitpriority=0 - ) + ) def _new_handle(self, first=False): flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED @@ -716,7 +666,7 @@ def _new_handle(self, first=False): _winapi.PIPE_WAIT, _winapi.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL - ) + ) def accept(self): self._handle_queue.append(self._new_handle()) @@ -730,7 +680,7 @@ def accept(self): # written data and then disconnected -- see Issue 14725. 
else: try: - _winapi.WaitForMultipleObjects( + res = _winapi.WaitForMultipleObjects( [ov.event], False, INFINITE) except: ov.cancel() @@ -747,8 +697,7 @@ def _finalize_pipe_listener(queue, address): for handle in queue: _winapi.CloseHandle(handle) - def PipeClient(address, _ignore=(_winapi.ERROR_SEM_TIMEOUT, - _winapi.ERROR_PIPE_BUSY)): + def PipeClient(address): ''' Return a connection object connected to the pipe given by `address` ''' @@ -760,9 +709,10 @@ def PipeClient(address, _ignore=(_winapi.ERROR_SEM_TIMEOUT, address, _winapi.GENERIC_READ | _winapi.GENERIC_WRITE, 0, _winapi.NULL, _winapi.OPEN_EXISTING, _winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL - ) + ) except OSError as e: - if e.winerror not in _ignore or _check_timeout(t): + if e.winerror not in (_winapi.ERROR_SEM_TIMEOUT, + _winapi.ERROR_PIPE_BUSY) or _check_timeout(t): raise else: break @@ -771,7 +721,7 @@ def PipeClient(address, _ignore=(_winapi.ERROR_SEM_TIMEOUT, _winapi.SetNamedPipeHandleState( h, _winapi.PIPE_READMODE_MESSAGE, None, None - ) + ) return PipeConnection(h) # @@ -784,10 +734,11 @@ def PipeClient(address, _ignore=(_winapi.ERROR_SEM_TIMEOUT, WELCOME = b'#WELCOME#' FAILURE = b'#FAILURE#' - def deliver_challenge(connection, authkey): import hmac - assert isinstance(authkey, bytes) + if not isinstance(authkey, bytes): + raise ValueError( + "Authkey must be bytes, not {0!s}".format(type(authkey))) message = os.urandom(MESSAGE_LENGTH) connection.send_bytes(CHALLENGE + message) digest = hmac.new(authkey, message, 'md5').digest() @@ -798,10 +749,11 @@ def deliver_challenge(connection, authkey): connection.send_bytes(FAILURE) raise AuthenticationError('digest received was wrong') - def answer_challenge(connection, authkey): import hmac - assert isinstance(authkey, bytes) + if not isinstance(authkey, bytes): + raise ValueError( + "Authkey must be bytes, not {0!s}".format(type(authkey))) message = connection.recv_bytes(256) # reject large message assert message[:len(CHALLENGE)] == CHALLENGE, 
'message = %r' % message message = message[len(CHALLENGE):] @@ -815,9 +767,7 @@ def answer_challenge(connection, authkey): # Support for using xmlrpclib for serialization # - class ConnectionWrapper(object): - def __init__(self, conn, dumps, loads): self._conn = conn self._dumps = dumps @@ -825,38 +775,30 @@ def __init__(self, conn, dumps, loads): for attr in ('fileno', 'close', 'poll', 'recv_bytes', 'send_bytes'): obj = getattr(conn, attr) setattr(self, attr, obj) - def send(self, obj): s = self._dumps(obj) self._conn.send_bytes(s) - def recv(self): s = self._conn.recv_bytes() return self._loads(s) - def _xml_dumps(obj): - o = xmlrpclib.dumps((obj, ), None, None, None, 1) # noqa - return o.encode('utf-8') - + return xmlrpclib.dumps((obj,), None, None, None, 1).encode('utf-8') def _xml_loads(s): - (obj,), method = xmlrpclib.loads(s.decode('utf-8')) # noqa + (obj,), method = xmlrpclib.loads(s.decode('utf-8')) return obj - class XmlListener(Listener): - def accept(self): global xmlrpclib - import xmlrpc.client as xmlrpclib # noqa + import xmlrpc.client as xmlrpclib obj = Listener.accept(self) return ConnectionWrapper(obj, _xml_dumps, _xml_loads) - def XmlClient(*args, **kwds): global xmlrpclib - import xmlrpc.client as xmlrpclib # noqa + import xmlrpc.client as xmlrpclib return ConnectionWrapper(Client(*args, **kwds), _xml_dumps, _xml_loads) # @@ -866,8 +808,8 @@ def XmlClient(*args, **kwds): if sys.platform == 'win32': def _exhaustive_wait(handles, timeout): - # Return ALL handles which are currently signaled. (Only - # returning the first signaled might create starvation issues.) + # Return ALL handles which are currently signalled. (Only + # returning the first signalled might create starvation issues.) 
L = list(handles) ready = [] while L: @@ -881,7 +823,7 @@ def _exhaustive_wait(handles, timeout): else: raise RuntimeError('Should not get here') ready.append(L[res]) - L = L[res + 1:] + L = L[res+1:] timeout = 0 return ready @@ -917,7 +859,7 @@ def wait(object_list, timeout=None): try: ov, err = _winapi.ReadFile(fileno(), 0, True) except OSError as e: - err = e.winerror + ov, err = None, e.winerror if err not in _ready_errors: raise if err == _winapi.ERROR_IO_PENDING: @@ -926,7 +868,16 @@ def wait(object_list, timeout=None): else: # If o.fileno() is an overlapped pipe handle and # err == 0 then there is a zero length message - # in the pipe, but it HAS NOT been consumed. + # in the pipe, but it HAS NOT been consumed... + if ov and sys.getwindowsversion()[:2] >= (6, 2): + # ... except on Windows 8 and later, where + # the message HAS been consumed. + try: + _, err = ov.GetOverlappedResult(False) + except OSError as e: + err = e.winerror + if not err and hasattr(o, '_got_empty_message'): + o._got_empty_message = True ready_objects.add(o) timeout = 0 @@ -954,51 +905,42 @@ def wait(object_list, timeout=None): o._got_empty_message = True ready_objects.update(waithandle_to_obj[h] for h in ready_handles) - return [p for p in object_list if p in ready_objects] + return [o for o in object_list if o in ready_objects] else: - if hasattr(select, 'poll'): - def _poll(fds, timeout): - if timeout is not None: - timeout = int(timeout * 1000) # timeout is in milliseconds - fd_map = {} - pollster = select.poll() - for fd in fds: - pollster.register(fd, select.POLLIN) - if hasattr(fd, 'fileno'): - fd_map[fd.fileno()] = fd - else: - fd_map[fd] = fd - ls = [] - for fd, event in pollster.poll(timeout): - if event & select.POLLNVAL: - raise ValueError('invalid file descriptor %i' % fd) - ls.append(fd_map[fd]) - return ls + import selectors + + # poll/select have the advantage of not requiring any extra file + # descriptor, contrarily to epoll/kqueue (also, they require a single + # 
syscall). + if hasattr(selectors, 'PollSelector'): + _WaitSelector = selectors.PollSelector else: - def _poll(fds, timeout): # noqa - return select.select(fds, [], [], timeout)[0] + _WaitSelector = selectors.SelectSelector - def wait(object_list, timeout=None): # noqa + def wait(object_list, timeout=None): ''' Wait till an object in object_list is ready/readable. Returns list of those objects in object_list which are ready/readable. ''' - if timeout is not None: - if timeout <= 0: - return _poll(object_list, 0) - else: - deadline = monotonic() + timeout - while True: - try: - return _poll(object_list, timeout) - except (OSError, IOError, socket.error) as e: - if e.errno != errno.EINTR: - raise + with _WaitSelector() as selector: + for obj in object_list: + selector.register(obj, selectors.EVENT_READ) + if timeout is not None: - timeout = deadline - monotonic() + deadline = time.monotonic() + timeout + + while True: + ready = selector.select(timeout) + if ready: + return [key.fileobj for (key, events) in ready] + else: + if timeout is not None: + timeout = deadline - time.monotonic() + if timeout < 0: + return ready # # Make connection and socket objects sharable if possible @@ -1011,10 +953,9 @@ def reduce_connection(conn): from . 
import resource_sharer ds = resource_sharer.DupSocket(s) return rebuild_connection, (ds, conn.readable, conn.writable) - def rebuild_connection(ds, readable, writable): sock = ds.detach() - return Connection(detach(sock), readable, writable) + return Connection(sock.detach(), readable, writable) reduction.register(Connection, reduce_connection) def reduce_pipe_connection(conn): @@ -1022,16 +963,16 @@ def reduce_pipe_connection(conn): (_winapi.FILE_GENERIC_WRITE if conn.writable else 0)) dh = reduction.DupHandle(conn.fileno(), access) return rebuild_pipe_connection, (dh, conn.readable, conn.writable) - def rebuild_pipe_connection(dh, readable, writable): - return PipeConnection(detach(dh), readable, writable) + handle = dh.detach() + return PipeConnection(handle, readable, writable) reduction.register(PipeConnection, reduce_pipe_connection) else: def reduce_connection(conn): df = reduction.DupFd(conn.fileno()) return rebuild_connection, (df, conn.readable, conn.writable) - def rebuild_connection(df, readable, writable): - return Connection(detach(df), readable, writable) + fd = df.detach() + return Connection(fd, readable, writable) reduction.register(Connection, reduce_connection) diff --git a/billiard/context.py b/billiard/context.py index 3bfa8dc8..8d0525d5 100644 --- a/billiard/context.py +++ b/billiard/context.py @@ -1,38 +1,30 @@ -from __future__ import absolute_import - import os import sys import threading -import warnings from . import process +from . 
import reduction -__all__ = [] # things are copied from here to __init__.py - - -W_NO_EXECV = """\ -force_execv is not supported as the billiard C extension \ -is not installed\ -""" - +__all__ = () # # Exceptions # -from .exceptions import ( # noqa - ProcessError, - BufferTooShort, - TimeoutError, - AuthenticationError, - TimeLimitExceeded, - SoftTimeLimitExceeded, - WorkerLostError, -) +class ProcessError(Exception): + pass + +class BufferTooShort(ProcessError): + pass + +class TimeoutError(ProcessError): + pass +class AuthenticationError(ProcessError): + pass # -# Base type for contexts +# Base type for contexts. Bound methods of an instance of this type are included in __all__ of __init__.py # class BaseContext(object): @@ -41,47 +33,18 @@ class BaseContext(object): BufferTooShort = BufferTooShort TimeoutError = TimeoutError AuthenticationError = AuthenticationError - TimeLimitExceeded = TimeLimitExceeded - SoftTimeLimitExceeded = SoftTimeLimitExceeded - WorkerLostError = WorkerLostError current_process = staticmethod(process.current_process) + parent_process = staticmethod(process.parent_process) active_children = staticmethod(process.active_children) - if hasattr(os, 'cpu_count'): - def cpu_count(self): - '''Returns the number of CPUs in the system''' - num = os.cpu_count() - if num is None: - raise NotImplementedError('cannot determine number of cpus') - else: - return num - else: - def cpu_count(self): # noqa - if sys.platform == 'win32': - try: - num = int(os.environ['NUMBER_OF_PROCESSORS']) - except (ValueError, KeyError): - num = 0 - elif 'bsd' in sys.platform or sys.platform == 'darwin': - comm = '/sbin/sysctl -n hw.ncpu' - if sys.platform == 'darwin': - comm = '/usr' + comm - try: - with os.popen(comm) as p: - num = int(p.read()) - except ValueError: - num = 0 - else: - try: - num = os.sysconf('SC_NPROCESSORS_ONLN') - except (ValueError, OSError, AttributeError): - num = 0 - - if num >= 1: - return num - else: - raise NotImplementedError('cannot 
determine number of cpus') + def cpu_count(self): + '''Returns the number of CPUs in the system''' + num = os.cpu_count() + if num is None: + raise NotImplementedError('cannot determine number of cpus') + else: + return num def Manager(self): '''Returns a manager associated with a running server process @@ -94,10 +57,10 @@ def Manager(self): m.start() return m - def Pipe(self, duplex=True, rnonblock=False, wnonblock=False): + def Pipe(self, duplex=True): '''Returns two connection object connected by a pipe''' from .connection import Pipe - return Pipe(duplex, rnonblock, wnonblock) + return Pipe(duplex) def Lock(self): '''Returns a non-recursive lock object''' @@ -150,18 +113,10 @@ def SimpleQueue(self): return SimpleQueue(ctx=self.get_context()) def Pool(self, processes=None, initializer=None, initargs=(), - maxtasksperchild=None, timeout=None, soft_timeout=None, - lost_worker_timeout=None, max_restarts=None, - max_restart_freq=1, on_process_up=None, on_process_down=None, - on_timeout_set=None, on_timeout_cancel=None, threads=True, - semaphore=None, putlocks=False, allow_restart=False): + maxtasksperchild=None): '''Returns a process pool object''' from .pool import Pool return Pool(processes, initializer, initargs, maxtasksperchild, - timeout, soft_timeout, lost_worker_timeout, - max_restarts, max_restart_freq, on_process_up, - on_process_down, on_timeout_set, on_timeout_cancel, - threads, semaphore, putlocks, allow_restart, context=self.get_context()) def RawValue(self, typecode_or_type, *args): @@ -174,17 +129,15 @@ def RawArray(self, typecode_or_type, size_or_initializer): from .sharedctypes import RawArray return RawArray(typecode_or_type, size_or_initializer) - def Value(self, typecode_or_type, *args, **kwargs): + def Value(self, typecode_or_type, *args, lock=True): '''Returns a synchronized shared object''' from .sharedctypes import Value - lock = kwargs.get('lock', True) return Value(typecode_or_type, *args, lock=lock, ctx=self.get_context()) - def 
Array(self, typecode_or_type, size_or_initializer, *args, **kwargs): + def Array(self, typecode_or_type, size_or_initializer, *, lock=True): '''Returns a synchronized shared array''' from .sharedctypes import Array - lock = kwargs.get('lock', True) return Array(typecode_or_type, size_or_initializer, lock=lock, ctx=self.get_context()) @@ -214,7 +167,7 @@ def allow_connection_pickling(self): ''' # This is undocumented. In previous versions of multiprocessing # its only effect was to make socket objects inheritable on Windows. - from . import connection # noqa + from . import connection def set_executable(self, executable): '''Sets the path to a python.exe or pythonw.exe binary used to run @@ -237,28 +190,25 @@ def get_context(self, method=None): try: ctx = _concrete_contexts[method] except KeyError: - raise ValueError('cannot find context for %r' % method) + raise ValueError('cannot find context for %r' % method) from None ctx._check_available() return ctx def get_start_method(self, allow_none=False): return self._name - def set_start_method(self, method=None): + def set_start_method(self, method, force=False): raise ValueError('cannot set start method of concrete context') - def forking_is_enabled(self): - # XXX for compatibility with billiard <3.4 - return (self.get_start_method() or 'fork') == 'fork' + @property + def reducer(self): + '''Controls how objects will be reduced to a form that can be + shared with other processes.''' + return globals().get('reduction') - def forking_enable(self, value): - # XXX for compatibility with billiard <3.4 - if not value: - from ._ext import supports_exec - if supports_exec: - self.set_start_method('spawn', force=True) - else: - warnings.warn(RuntimeWarning(W_NO_EXECV)) + @reducer.setter + def reducer(self, reduction): + globals()['reduction'] = reduction def _check_available(self): pass @@ -267,15 +217,12 @@ def _check_available(self): # Type of default context -- underlying context can be set at most once # - class 
Process(process.BaseProcess): _start_method = None - @staticmethod def _Popen(process_obj): return _default_context.get_context().Process._Popen(process_obj) - class DefaultContext(BaseContext): Process = Process @@ -289,7 +236,7 @@ def get_context(self, method=None): self._actual_context = self._default_context return self._actual_context else: - return super(DefaultContext, self).get_context(method) + return super().get_context(method) def set_start_method(self, method, force=False): if self._actual_context is not None and not force: @@ -310,13 +257,11 @@ def get_all_start_methods(self): if sys.platform == 'win32': return ['spawn'] else: - from . import reduction + methods = ['spawn', 'fork'] if sys.platform == 'darwin' else ['fork', 'spawn'] if reduction.HAVE_SEND_HANDLE: - return ['fork', 'spawn', 'forkserver'] - else: - return ['fork', 'spawn'] + methods.append('forkserver') + return methods -DefaultContext.__all__ = list(x for x in dir(DefaultContext) if x[0] != '_') # # Context types for fixed start method @@ -326,7 +271,6 @@ def get_all_start_methods(self): class ForkProcess(process.BaseProcess): _start_method = 'fork' - @staticmethod def _Popen(process_obj): from .popen_fork import Popen @@ -334,7 +278,6 @@ def _Popen(process_obj): class SpawnProcess(process.BaseProcess): _start_method = 'spawn' - @staticmethod def _Popen(process_obj): from .popen_spawn_posix import Popen @@ -342,7 +285,6 @@ def _Popen(process_obj): class ForkServerProcess(process.BaseProcess): _start_method = 'forkserver' - @staticmethod def _Popen(process_obj): from .popen_forkserver import Popen @@ -359,9 +301,7 @@ class SpawnContext(BaseContext): class ForkServerContext(BaseContext): _name = 'forkserver' Process = ForkServerProcess - def _check_available(self): - from . 
import reduction if not reduction.HAVE_SEND_HANDLE: raise ValueError('forkserver start method not available') @@ -370,13 +310,17 @@ def _check_available(self): 'spawn': SpawnContext(), 'forkserver': ForkServerContext(), } - _default_context = DefaultContext(_concrete_contexts['fork']) + if sys.platform == 'darwin': + # bpo-33725: running arbitrary code after fork() is no longer reliable + # on macOS since macOS 10.14 (Mojave). Use spawn by default instead. + _default_context = DefaultContext(_concrete_contexts['spawn']) + else: + _default_context = DefaultContext(_concrete_contexts['fork']) else: class SpawnProcess(process.BaseProcess): _start_method = 'spawn' - @staticmethod def _Popen(process_obj): from .popen_spawn_win32 import Popen @@ -395,7 +339,6 @@ class SpawnContext(BaseContext): # Force the start method # - def _force_start_method(method): _default_context._actual_context = _concrete_contexts[method] @@ -405,18 +348,15 @@ def _force_start_method(method): _tls = threading.local() - def get_spawning_popen(): return getattr(_tls, 'spawning_popen', None) - def set_spawning_popen(popen): _tls.spawning_popen = popen - def assert_spawning(obj): if get_spawning_popen() is None: raise RuntimeError( '%s objects should only be shared between processes' ' through inheritance' % type(obj).__name__ - ) + ) diff --git a/billiard/dummy/__init__.py b/billiard/dummy/__init__.py index 032f5f15..6a146860 100644 --- a/billiard/dummy/__init__.py +++ b/billiard/dummy/__init__.py @@ -4,34 +4,14 @@ # multiprocessing/dummy/__init__.py # # Copyright (c) 2006-2008, R Oudkerk -# All rights reserved. +# Licensed to PSF under a Contributor Agreement. # -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# 1. Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# 2. 
Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# 3. Neither the name of author nor the names of any contributors may be -# used to endorse or promote products derived from this software -# without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND -# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS -# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -# SUCH DAMAGE. 
-# -from __future__ import absolute_import + +__all__ = [ + 'Process', 'current_process', 'active_children', 'freeze_support', + 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', + 'Event', 'Barrier', 'Queue', 'Manager', 'Pipe', 'Pool', 'JoinableQueue' + ] # # Imports @@ -42,19 +22,14 @@ import weakref import array +from .connection import Pipe from threading import Lock, RLock, Semaphore, BoundedSemaphore -from threading import Event - -from billiard.five import Queue - -from billiard.connection import Pipe - -__all__ = [ - 'Process', 'current_process', 'active_children', 'freeze_support', - 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', - 'Event', 'Queue', 'Manager', 'Pipe', 'Pool', 'JoinableQueue' -] +from threading import Event, Condition, Barrier +from queue import Queue +# +# +# class DummyProcess(threading.Thread): @@ -66,7 +41,10 @@ def __init__(self, group=None, target=None, name=None, args=(), kwargs={}): self._parent = current_process() def start(self): - assert self._parent is current_process() + if self._parent is not current_process(): + raise RuntimeError( + "Parent is {0!r} but current_process is {1!r}".format( + self._parent, current_process())) self._start_called = True if hasattr(self._parent, '_children'): self._parent._children[self] = None @@ -79,25 +57,14 @@ def exitcode(self): else: return None - -try: - _Condition = threading._Condition -except AttributeError: # Py3 - _Condition = threading.Condition # noqa - - -class Condition(_Condition): - if sys.version_info[0] == 3: - notify_all = _Condition.notifyAll - else: - notify_all = _Condition.notifyAll.__func__ - +# +# +# Process = DummyProcess current_process = threading.current_thread current_process()._children = weakref.WeakKeyDictionary() - def active_children(): children = current_process()._children for p in list(children): @@ -105,16 +72,16 @@ def active_children(): children.pop(p, None) return list(children) - def freeze_support(): pass +# +# +# class 
Namespace(object): - - def __init__(self, **kwds): + def __init__(self, /, **kwds): self.__dict__.update(kwds) - def __repr__(self): items = list(self.__dict__.items()) temp = [] @@ -122,46 +89,38 @@ def __repr__(self): if not name.startswith('_'): temp.append('%s=%r' % (name, value)) temp.sort() - return '%s(%s)' % (self.__class__.__name__, str.join(', ', temp)) - + return '%s(%s)' % (self.__class__.__name__, ', '.join(temp)) dict = dict list = list - def Array(typecode, sequence, lock=True): return array.array(typecode, sequence) - class Value(object): - def __init__(self, typecode, value, lock=True): self._typecode = typecode self._value = value - def _get(self): + @property + def value(self): return self._value - def _set(self, value): + @value.setter + def value(self, value): self._value = value - value = property(_get, _set) def __repr__(self): - return '<%r(%r, %r)>' % (type(self).__name__, - self._typecode, self._value) - + return '<%s(%r, %r)>'%(type(self).__name__,self._typecode,self._value) def Manager(): return sys.modules[__name__] - def shutdown(): pass - def Pool(processes=None, initializer=None, initargs=()): - from billiard.pool import ThreadPool + from ..pool import ThreadPool return ThreadPool(processes, initializer, initargs) - JoinableQueue = Queue diff --git a/billiard/dummy/connection.py b/billiard/dummy/connection.py index 6bf6b9d0..f0ce320f 100644 --- a/billiard/dummy/connection.py +++ b/billiard/dummy/connection.py @@ -4,38 +4,13 @@ # multiprocessing/dummy/connection.py # # Copyright (c) 2006-2008, R Oudkerk -# All rights reserved. +# Licensed to PSF under a Contributor Agreement. # -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# 1. Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# 2. 
Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# 3. Neither the name of author nor the names of any contributors may be -# used to endorse or promote products derived from this software -# without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND -# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS -# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -# SUCH DAMAGE. 
-# -from __future__ import absolute_import -from billiard.five import Queue +__all__ = [ 'Client', 'Listener', 'Pipe' ] + +from queue import Queue -__all__ = ['Client', 'Listener', 'Pipe'] families = [None] @@ -51,12 +26,14 @@ def accept(self): def close(self): self._backlog_queue = None - address = property(lambda self: self._backlog_queue) + @property + def address(self): + return self._backlog_queue def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__(self, exc_type, exc_value, exc_tb): self.close() @@ -84,10 +61,15 @@ def poll(self, timeout=0.0): return True if timeout <= 0.0: return False - self._in.not_empty.acquire() - self._in.not_empty.wait(timeout) - self._in.not_empty.release() + with self._in.not_empty: + self._in.not_empty.wait(timeout) return self._in.qsize() > 0 def close(self): pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self.close() diff --git a/billiard/forkserver.py b/billiard/forkserver.py index a963edbc..22a911a7 100644 --- a/billiard/forkserver.py +++ b/billiard/forkserver.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import, print_function - import errno import os import selectors @@ -8,16 +6,15 @@ import struct import sys import threading +import warnings from . import connection from . import process -from . import reduction -from . import semaphore_tracker +from .context import reduction +from . import resource_tracker from . import spawn from . 
import util -from .compat import spawnv_passfds - __all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process', 'set_forkserver_preload'] @@ -26,22 +23,42 @@ # MAXFDS_TO_SEND = 256 -UNSIGNED_STRUCT = struct.Struct('Q') # large enough for pid_t +SIGNED_STRUCT = struct.Struct('q') # large enough for pid_t # # Forkserver class # - class ForkServer(object): def __init__(self): self._forkserver_address = None self._forkserver_alive_fd = None + self._forkserver_pid = None self._inherited_fds = None self._lock = threading.Lock() self._preload_modules = ['__main__'] + def _stop(self): + # Method used by unit tests to stop the server + with self._lock: + self._stop_unlocked() + + def _stop_unlocked(self): + if self._forkserver_pid is None: + return + + # close the "alive" file descriptor asks the server to stop + os.close(self._forkserver_alive_fd) + self._forkserver_alive_fd = None + + os.waitpid(self._forkserver_pid, 0) + self._forkserver_pid = None + + if not util.is_abstract_socket_namespace(self._forkserver_address): + os.unlink(self._forkserver_address) + self._forkserver_address = None + def set_forkserver_preload(self, modules_names): '''Set list of module names to try to load in forkserver process.''' if not all(type(mod) is str for mod in self._preload_modules): @@ -72,7 +89,7 @@ def connect_to_new_process(self, fds): parent_r, child_w = os.pipe() child_r, parent_w = os.pipe() allfds = [child_r, child_w, self._forkserver_alive_fd, - semaphore_tracker.getfd()] + resource_tracker.getfd()] allfds += fds try: reduction.sendfds(client, allfds) @@ -93,26 +110,34 @@ def ensure_running(self): ensure_running() will do nothing. ''' with self._lock: - semaphore_tracker.ensure_running() - if self._forkserver_alive_fd is not None: - return - - cmd = ('from billiard.forkserver import main; ' + + resource_tracker.ensure_running() + if self._forkserver_pid is not None: + # forkserver was launched before, is it still running? 
+ pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG) + if not pid: + # still alive + return + # dead, launch it again + os.close(self._forkserver_alive_fd) + self._forkserver_address = None + self._forkserver_alive_fd = None + self._forkserver_pid = None + + cmd = ('from billiard.forkserver import main; ' + 'main(%d, %d, %r, **%r)') if self._preload_modules: desired_keys = {'main_path', 'sys_path'} data = spawn.get_preparation_data('ignore') - data = { - x: y for (x, y) in data.items() if x in desired_keys - } + data = {x: y for x, y in data.items() if x in desired_keys} else: data = {} with socket.socket(socket.AF_UNIX) as listener: address = connection.arbitrary_address('AF_UNIX') listener.bind(address) - os.chmod(address, 0o600) + if not util.is_abstract_socket_namespace(address): + os.chmod(address, 0o600) listener.listen() # all client processes own the write end of the "alive" pipe; @@ -125,7 +150,7 @@ def ensure_running(self): exe = spawn.get_executable() args = [exe] + util._args_from_interpreter_flags() args += ['-c', cmd] - spawnv_passfds(exe, args, fds_to_pass) + pid = util.spawnv_passfds(exe, args, fds_to_pass) except: os.close(alive_w) raise @@ -133,12 +158,12 @@ def ensure_running(self): os.close(alive_r) self._forkserver_address = address self._forkserver_alive_fd = alive_w + self._forkserver_pid = pid # # # - def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): '''Run forkserver.''' if preload: @@ -154,21 +179,38 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): except ImportError: pass - # close sys.stdin - if sys.stdin is not None: - try: - sys.stdin.close() - sys.stdin = open(os.devnull) - except (OSError, ValueError): - pass + util._close_stdin() + + sig_r, sig_w = os.pipe() + os.set_blocking(sig_r, False) + os.set_blocking(sig_w, False) + + def sigchld_handler(*_unused): + # Dummy signal handler, doesn't do anything + pass + + handlers = { + # unblocking SIGCHLD allows the wakeup fd to 
notify our event loop + signal.SIGCHLD: sigchld_handler, + # protect the process from ^C + signal.SIGINT: signal.SIG_IGN, + } + old_handlers = {sig: signal.signal(sig, val) + for (sig, val) in handlers.items()} + + # calling os.write() in the Python signal handler is racy + signal.set_wakeup_fd(sig_w) + + # map child pids to client fds + pid_to_fd = {} - # ignoring SIGCHLD means no need to reap zombie processes - handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN) with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \ - selectors.DefaultSelector() as selector: + selectors.DefaultSelector() as selector: _forkserver._forkserver_address = listener.getsockname() + selector.register(listener, selectors.EVENT_READ) selector.register(alive_r, selectors.EVENT_READ) + selector.register(sig_r, selectors.EVENT_READ) while True: try: @@ -179,76 +221,116 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): if alive_r in rfds: # EOF because no more client processes left - assert os.read(alive_r, 1) == b'' + assert os.read(alive_r, 1) == b'', "Not at EOF?" 
raise SystemExit - assert listener in rfds - with listener.accept()[0] as s: - code = 1 - if os.fork() == 0: + if sig_r in rfds: + # Got SIGCHLD + os.read(sig_r, 65536) # exhaust + while True: + # Scan for child processes try: - _serve_one(s, listener, alive_r, handler) - except Exception: - sys.excepthook(*sys.exc_info()) - sys.stderr.flush() - finally: - os._exit(code) + pid, sts = os.waitpid(-1, os.WNOHANG) + except ChildProcessError: + break + if pid == 0: + break + child_w = pid_to_fd.pop(pid, None) + if child_w is not None: + returncode = os.waitstatus_to_exitcode(sts) + + # Send exit code to client process + try: + write_signed(child_w, returncode) + except BrokenPipeError: + # client vanished + pass + os.close(child_w) + else: + # This shouldn't happen really + warnings.warn('forkserver: waitpid returned ' + 'unexpected pid %d' % pid) + + if listener in rfds: + # Incoming fork request + with listener.accept()[0] as s: + # Receive fds from client + fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1) + if len(fds) > MAXFDS_TO_SEND: + raise RuntimeError( + "Too many ({0:n}) fds to send".format( + len(fds))) + child_r, child_w, *fds = fds + s.close() + pid = os.fork() + if pid == 0: + # Child + code = 1 + try: + listener.close() + selector.close() + unused_fds = [alive_r, child_w, sig_r, sig_w] + unused_fds.extend(pid_to_fd.values()) + code = _serve_one(child_r, fds, + unused_fds, + old_handlers) + except Exception: + sys.excepthook(*sys.exc_info()) + sys.stderr.flush() + finally: + os._exit(code) + else: + # Send pid to client process + try: + write_signed(child_w, pid) + except BrokenPipeError: + # client vanished + pass + pid_to_fd[pid] = child_w + os.close(child_r) + for fd in fds: + os.close(fd) + except OSError as e: if e.errno != errno.ECONNABORTED: raise -def __unpack_fds(child_r, child_w, alive, stfd, *inherited): - return child_r, child_w, alive, stfd, inherited +def _serve_one(child_r, fds, unused_fds, handlers): + # close unnecessary stuff and reset 
signal handlers + signal.set_wakeup_fd(-1) + for sig, val in handlers.items(): + signal.signal(sig, val) + for fd in unused_fds: + os.close(fd) + (_forkserver._forkserver_alive_fd, + resource_tracker._resource_tracker._fd, + *_forkserver._inherited_fds) = fds -def _serve_one(s, listener, alive_r, handler): - # close unnecessary stuff and reset SIGCHLD handler - listener.close() - os.close(alive_r) - signal.signal(signal.SIGCHLD, handler) + # Run process object received over pipe + parent_sentinel = os.dup(child_r) + code = spawn._main(child_r, parent_sentinel) - # receive fds from parent process - fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1) - s.close() - assert len(fds) <= MAXFDS_TO_SEND + return code - (child_r, child_w, _forkserver._forkserver_alive_fd, - stfd, _forkserver._inherited_fds) = __unpack_fds(*fds) - semaphore_tracker._semaphore_tracker._fd = stfd - - # send pid to client processes - write_unsigned(child_w, os.getpid()) - - # reseed random number generator - if 'random' in sys.modules: - import random - random.seed() - - # run process object received over pipe - code = spawn._main(child_r) - - # write the exit code to the pipe - write_unsigned(child_w, code) # -# Read and write unsigned numbers +# Read and write signed numbers # - -def read_unsigned(fd): +def read_signed(fd): data = b'' - length = UNSIGNED_STRUCT.size + length = SIGNED_STRUCT.size while len(data) < length: s = os.read(fd, length - len(data)) if not s: raise EOFError('unexpected EOF') data += s - return UNSIGNED_STRUCT.unpack(data)[0] - + return SIGNED_STRUCT.unpack(data)[0] -def write_unsigned(fd, n): - msg = UNSIGNED_STRUCT.pack(n) +def write_signed(fd, n): + msg = SIGNED_STRUCT.pack(n) while msg: nbytes = os.write(fd, msg) if nbytes == 0: diff --git a/billiard/heap.py b/billiard/heap.py index b7581cec..6217dfe1 100644 --- a/billiard/heap.py +++ b/billiard/heap.py @@ -6,34 +6,32 @@ # Copyright (c) 2006-2008, R Oudkerk # Licensed to PSF under a Contributor Agreement. 
# -from __future__ import absolute_import import bisect -import errno -import io +from collections import defaultdict import mmap import os import sys -import threading import tempfile +import threading -from . import context -from . import reduction +from .context import reduction, assert_spawning from . import util -from ._ext import _billiard, win32 - __all__ = ['BufferWrapper'] -PY3 = sys.version_info[0] == 3 - # # Inheritable class which wraps an mmap, and from which blocks can be allocated # if sys.platform == 'win32': + import _winapi + class Arena(object): + """ + A shared memory area backed by anonymous memory (Windows). + """ _rand = tempfile._RandomNameSequence() @@ -42,70 +40,66 @@ def __init__(self, size): for i in range(100): name = 'pym-%d-%s' % (os.getpid(), next(self._rand)) buf = mmap.mmap(-1, size, tagname=name) - if win32.GetLastError() == 0: + if _winapi.GetLastError() == 0: break - # we have reopened a preexisting map + # We have reopened a preexisting mmap. buf.close() else: - exc = IOError('Cannot find name for new mmap') - exc.errno = errno.EEXIST - raise exc + raise FileExistsError('Cannot find name for new mmap') self.name = name self.buffer = buf self._state = (self.size, self.name) def __getstate__(self): - context.assert_spawning(self) + assert_spawning(self) return self._state def __setstate__(self, state): self.size, self.name = self._state = state + # Reopen existing mmap self.buffer = mmap.mmap(-1, self.size, tagname=self.name) # XXX Temporarily preventing buildbot failures while determining - # XXX the correct long-term fix. See issue #23060 - # assert win32.GetLastError() == win32.ERROR_ALREADY_EXISTS + # XXX the correct long-term fix. See issue 23060 + #assert _winapi.GetLastError() == _winapi.ERROR_ALREADY_EXISTS else: class Arena(object): + """ + A shared memory area backed by a temporary file (POSIX). 
+ """ + + if sys.platform == 'linux': + _dir_candidates = ['/dev/shm'] + else: + _dir_candidates = [] def __init__(self, size, fd=-1): self.size = size self.fd = fd if fd == -1: - if PY3: - self.fd, name = tempfile.mkstemp( - prefix='pym-%d-' % (os.getpid(),), - dir=util.get_temp_dir(), - ) - - os.unlink(name) - util.Finalize(self, os.close, (self.fd,)) - with io.open(self.fd, 'wb', closefd=False) as f: - bs = 1024 * 1024 - if size >= bs: - zeros = b'\0' * bs - for _ in range(size // bs): - f.write(zeros) - del(zeros) - f.write(b'\0' * (size % bs)) - assert f.tell() == size - else: - name = tempfile.mktemp( - prefix='pym-%d-' % (os.getpid(),), - dir=util.get_temp_dir(), - ) - self.fd = os.open( - name, os.O_RDWR | os.O_CREAT | os.O_EXCL, 0o600, - ) - util.Finalize(self, os.close, (self.fd,)) - os.unlink(name) - os.ftruncate(self.fd, size) + # Arena is created anew (if fd != -1, it means we're coming + # from rebuild_arena() below) + self.fd, name = tempfile.mkstemp( + prefix='pym-%d-'%os.getpid(), + dir=self._choose_dir(size)) + os.unlink(name) + util.Finalize(self, os.close, (self.fd,)) + os.ftruncate(self.fd, size) self.buffer = mmap.mmap(self.fd, self.size) + def _choose_dir(self, size): + # Choose a non-storage backed directory if possible, + # to improve performance + for d in self._dir_candidates: + st = os.statvfs(d) + if st.f_bavail * st.f_frsize >= size: # enough free space? 
+ return d + return util.get_temp_dir() + def reduce_arena(a): if a.fd == -1: - raise ValueError('Arena is unpicklable because' + raise ValueError('Arena is unpicklable because ' 'forking was enabled when it was created') return rebuild_arena, (a.size, reduction.DupFd(a.fd)) @@ -118,40 +112,84 @@ def rebuild_arena(size, dupfd): # Class allowing allocation of chunks of memory from arenas # - class Heap(object): + # Minimum malloc() alignment _alignment = 8 + _DISCARD_FREE_SPACE_LARGER_THAN = 4 * 1024 ** 2 # 4 MB + _DOUBLE_ARENA_SIZE_UNTIL = 4 * 1024 ** 2 + def __init__(self, size=mmap.PAGESIZE): self._lastpid = os.getpid() self._lock = threading.Lock() + # Current arena allocation size self._size = size + # A sorted list of available block sizes in arenas self._lengths = [] + + # Free block management: + # - map each block size to a list of `(Arena, start, stop)` blocks self._len_to_seq = {} + # - map `(Arena, start)` tuple to the `(Arena, start, stop)` block + # starting at that offset self._start_to_block = {} + # - map `(Arena, stop)` tuple to the `(Arena, start, stop)` block + # ending at that offset self._stop_to_block = {} - self._allocated_blocks = set() + + # Map arenas to their `(Arena, start, stop)` blocks in use + self._allocated_blocks = defaultdict(set) self._arenas = [] - # list of pending blocks to free - see free() comment below + + # List of pending blocks to free - see comment in free() below self._pending_free_blocks = [] + # Statistics + self._n_mallocs = 0 + self._n_frees = 0 + @staticmethod def _roundup(n, alignment): # alignment must be a power of 2 mask = alignment - 1 return (n + mask) & ~mask + def _new_arena(self, size): + # Create a new arena with at least the given *size* + length = self._roundup(max(self._size, size), mmap.PAGESIZE) + # We carve larger and larger arenas, for efficiency, until we + # reach a large-ish size (roughly L3 cache-sized) + if self._size < self._DOUBLE_ARENA_SIZE_UNTIL: + self._size *= 2 + util.info('allocating 
a new mmap of length %d', length) + arena = Arena(length) + self._arenas.append(arena) + return (arena, 0, length) + + def _discard_arena(self, arena): + # Possibly delete the given (unused) arena + length = arena.size + # Reusing an existing arena is faster than creating a new one, so + # we only reclaim space if it's large enough. + if length < self._DISCARD_FREE_SPACE_LARGER_THAN: + return + blocks = self._allocated_blocks.pop(arena) + assert not blocks + del self._start_to_block[(arena, 0)] + del self._stop_to_block[(arena, length)] + self._arenas.remove(arena) + seq = self._len_to_seq[length] + seq.remove((arena, 0, length)) + if not seq: + del self._len_to_seq[length] + self._lengths.remove(length) + def _malloc(self, size): # returns a large enough block -- it might be much larger i = bisect.bisect_left(self._lengths, size) if i == len(self._lengths): - length = self._roundup(max(self._size, size), mmap.PAGESIZE) - self._size *= 2 - util.info('allocating a new mmap of length %d', length) - arena = Arena(length) - self._arenas.append(arena) - return (arena, 0, length) + return self._new_arena(size) else: length = self._lengths[i] seq = self._len_to_seq[length] @@ -164,8 +202,8 @@ def _malloc(self, size): del self._stop_to_block[(arena, stop)] return block - def _free(self, block): - # free location and try to merge with neighbours + def _add_free_block(self, block): + # make block available and try to merge with its neighbours in the arena (arena, start, stop) = block try: @@ -209,15 +247,23 @@ def _absorb(self, block): return start, stop + def _remove_allocated_block(self, block): + arena, start, stop = block + blocks = self._allocated_blocks[arena] + blocks.remove((start, stop)) + if not blocks: + # Arena is entirely free, discard it from this process + self._discard_arena(arena) + def _free_pending_blocks(self): - # Free all the blocks in the pending list - called with the lock held - while 1: + # Free all the blocks in the pending list - called with the 
lock held. + while True: try: block = self._pending_free_blocks.pop() except IndexError: break - self._allocated_blocks.remove(block) - self._free(block) + self._add_free_block(block) + self._remove_allocated_block(block) def free(self, block): # free a block returned by malloc() @@ -228,9 +274,11 @@ def free(self, block): # immediately, the block is added to a list of blocks to be freed # synchronously sometimes later from malloc() or free(), by calling # _free_pending_blocks() (appending and retrieving from a list is not - # strictly thread-safe but under cPython it's atomic - # thanks to the GIL). - assert os.getpid() == self._lastpid + # strictly thread-safe but under CPython it's atomic thanks to the GIL). + if os.getpid() != self._lastpid: + raise ValueError( + "My pid ({0:n}) is not last pid {1:n}".format( + os.getpid(),self._lastpid)) if not self._lock.acquire(False): # can't acquire the lock right now, add the block to the list of # pending blocks to free @@ -238,52 +286,52 @@ def free(self, block): else: # we hold the lock try: + self._n_frees += 1 self._free_pending_blocks() - self._allocated_blocks.remove(block) - self._free(block) + self._add_free_block(block) + self._remove_allocated_block(block) finally: self._lock.release() def malloc(self, size): # return a block of right size (possibly rounded up) - assert 0 <= size < sys.maxsize + if size < 0: + raise ValueError("Size {0:n} out of range".format(size)) + if sys.maxsize <= size: + raise OverflowError("Size {0:n} too large".format(size)) if os.getpid() != self._lastpid: self.__init__() # reinitialize after fork with self._lock: + self._n_mallocs += 1 + # allow pending blocks to be marked available self._free_pending_blocks() size = self._roundup(max(size, 1), self._alignment) (arena, start, stop) = self._malloc(size) - new_stop = start + size - if new_stop < stop: - self._free((arena, new_stop, stop)) - block = (arena, start, new_stop) - self._allocated_blocks.add(block) - return block + real_stop = 
start + size + if real_stop < stop: + # if the returned block is larger than necessary, mark + # the remainder available + self._add_free_block((arena, real_stop, stop)) + self._allocated_blocks[arena].add((start, real_stop)) + return (arena, start, real_stop) # -# Class representing a chunk of an mmap -- can be inherited +# Class wrapping a block allocated out of a Heap -- can be inherited by child process # - class BufferWrapper(object): _heap = Heap() def __init__(self, size): - assert 0 <= size < sys.maxsize + if size < 0: + raise ValueError("Size {0:n} out of range".format(size)) + if sys.maxsize <= size: + raise OverflowError("Size {0:n} too large".format(size)) block = BufferWrapper._heap.malloc(size) self._state = (block, size) util.Finalize(self, BufferWrapper._heap.free, args=(block,)) - def get_address(self): - (arena, start, stop), size = self._state - address, length = _billiard.address_of_buffer(arena.buffer) - assert size <= length - return address + start - - def get_size(self): - return self._state[1] - def create_memoryview(self): (arena, start, stop), size = self._state - return memoryview(arena.buffer)[start:start + size] + return memoryview(arena.buffer)[start:start+size] diff --git a/billiard/managers.py b/billiard/managers.py index 78e24f41..0eb16c66 100644 --- a/billiard/managers.py +++ b/billiard/managers.py @@ -1,5 +1,5 @@ # -# Module providing the `SyncManager` class for dealing +# Module providing manager classes for dealing # with shared objects # # multiprocessing/managers.py @@ -7,7 +7,9 @@ # Copyright (c) 2006-2008, R Oudkerk # Licensed to PSF under a Contributor Agreement. # -from __future__ import absolute_import + +__all__ = [ 'BaseManager', 'SyncManager', 'BaseProxy', 'Token', + 'SharedMemoryManager' ] # # Imports @@ -15,43 +17,40 @@ import sys import threading +import signal import array +import queue +import time +import types +import os +from os import getpid from traceback import format_exc from . import connection -from . 
import context +from .context import reduction, get_spawning_popen, ProcessError from . import pool from . import process -from . import reduction from . import util from . import get_context - -from .five import Queue, items, monotonic - -__all__ = ['BaseManager', 'SyncManager', 'BaseProxy', 'Token'] - -PY3 = sys.version_info[0] == 3 +try: + from . import shared_memory + HAS_SHMEM = True +except ImportError: + HAS_SHMEM = False # # Register some things for pickling # - -if PY3: - def reduce_array(a): - return array.array, (a.typecode, a.tobytes()) -else: - def reduce_array(a): # noqa - return array.array, (a.typecode, a.tostring()) +def reduce_array(a): + return array.array, (a.typecode, a.tobytes()) reduction.register(array.array, reduce_array) -view_types = [type(getattr({}, name)()) - for name in ('items', 'keys', 'values')] -if view_types[0] is not list: # only needed in Py3.0 - +view_types = [type(getattr({}, name)()) for name in ('items','keys','values')] +if view_types[0] is not list: # only needed in Py3.0 def rebuild_as_list(obj): - return list, (list(obj), ) + return list, (list(obj),) for view_type in view_types: reduction.register(view_type, rebuild_as_list) @@ -59,10 +58,9 @@ def rebuild_as_list(obj): # Type for identifying shared objects # - class Token(object): ''' - Type to uniquely indentify a shared object + Type to uniquely identify a shared object ''' __slots__ = ('typeid', 'address', 'id') @@ -77,13 +75,12 @@ def __setstate__(self, state): def __repr__(self): return '%s(typeid=%r, address=%r, id=%r)' % \ - (self.__class__.__name__, self.typeid, self.address, self.id) + (self.__class__.__name__, self.typeid, self.address, self.id) # # Function for communication with a manager's server process # - def dispatch(c, id, methodname, args=(), kwds={}): ''' Send a message to manager using connection `c` and return response @@ -94,30 +91,29 @@ def dispatch(c, id, methodname, args=(), kwds={}): return result raise convert_to_error(kind, result) - def 
convert_to_error(kind, result): if kind == '#ERROR': return result - elif kind == '#TRACEBACK': - assert type(result) is str - return RemoteError(result) - elif kind == '#UNSERIALIZABLE': - assert type(result) is str - return RemoteError('Unserializable message: %s\n' % result) + elif kind in ('#TRACEBACK', '#UNSERIALIZABLE'): + if not isinstance(result, str): + raise TypeError( + "Result {0!r} (kind '{1}') type is {2}, not str".format( + result, kind, type(result))) + if kind == '#UNSERIALIZABLE': + return RemoteError('Unserializable message: %s\n' % result) + else: + return RemoteError(result) else: - return ValueError('Unrecognized message type') - + return ValueError('Unrecognized message type {!r}'.format(kind)) class RemoteError(Exception): - def __str__(self): - return ('\n' + '-' * 75 + '\n' + str(self.args[0]) + '-' * 75) + return ('\n' + '-'*75 + '\n' + str(self.args[0]) + '-'*75) # # Functions for finding the method names of an object # - def all_methods(obj): ''' Return a list of names of methods of `obj` @@ -129,7 +125,6 @@ def all_methods(obj): temp.append(name) return temp - def public_methods(obj): ''' Return a list of names of methods of `obj` which do not start with '_' @@ -140,7 +135,6 @@ def public_methods(obj): # Server which is run in a process controlled by a manager # - class Server(object): ''' Server class which runs in a process controlled by a manager object @@ -149,7 +143,10 @@ class Server(object): 'debug_info', 'number_of_objects', 'dummy', 'incref', 'decref'] def __init__(self, registry, address, authkey, serializer): - assert isinstance(authkey, bytes) + if not isinstance(authkey, bytes): + raise TypeError( + "Authkey {0!r} is type {1!s}, not bytes".format( + authkey, type(authkey))) self.registry = registry self.authkey = process.AuthenticationString(authkey) Listener, Client = listener_client[serializer] @@ -160,7 +157,8 @@ def __init__(self, registry, address, authkey, serializer): self.id_to_obj = {'0': (None, ())} 
self.id_to_refcount = {} - self.mutex = threading.RLock() + self.id_to_local_proxy_obj = {} + self.mutex = threading.Lock() def serve_forever(self): ''' @@ -178,7 +176,7 @@ def serve_forever(self): except (KeyboardInterrupt, SystemExit): pass finally: - if sys.stdout != sys.__stdout__: + if sys.stdout != sys.__stdout__: # what about stderr? util.debug('resetting stdout, stderr') sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ @@ -190,7 +188,7 @@ def accepter(self): c = self.listener.accept() except OSError: continue - t = threading.Thread(target=self.handle_request, args=(c, )) + t = threading.Thread(target=self.handle_request, args=(c,)) t.daemon = True t.start() @@ -217,14 +215,14 @@ def handle_request(self, c): msg = ('#RETURN', result) try: c.send(msg) - except Exception as exc: + except Exception as e: try: c.send(('#TRACEBACK', format_exc())) except Exception: pass util.info('Failure to send message: %r', msg) util.info(' ... request was %r', request) - util.info(' ... exception was %r', exc) + util.info(' ... 
exception was %r', e) c.close() @@ -245,20 +243,27 @@ def serve_client(self, conn): methodname = obj = None request = recv() ident, methodname, args, kwds = request - obj, exposed, gettypeid = id_to_obj[ident] + try: + obj, exposed, gettypeid = id_to_obj[ident] + except KeyError as ke: + try: + obj, exposed, gettypeid = \ + self.id_to_local_proxy_obj[ident] + except KeyError: + raise ke if methodname not in exposed: raise AttributeError( - 'method %r of %r object is not in exposed=%r' % ( - methodname, type(obj), exposed) - ) + 'method %r of %r object is not in exposed=%r' % + (methodname, type(obj), exposed) + ) function = getattr(obj, methodname) try: res = function(*args, **kwds) - except Exception as exc: - msg = ('#ERROR', exc) + except Exception as e: + msg = ('#ERROR', e) else: typeid = gettypeid and gettypeid.get(methodname, None) if typeid: @@ -276,7 +281,7 @@ def serve_client(self, conn): fallback_func = self.fallback_mapping[methodname] result = fallback_func( self, conn, ident, obj, *args, **kwds - ) + ) msg = ('#RETURN', result) except Exception: msg = ('#TRACEBACK', format_exc()) @@ -293,12 +298,12 @@ def serve_client(self, conn): try: send(msg) except Exception: - send(('#UNSERIALIZABLE', repr(msg))) - except Exception as exc: + send(('#UNSERIALIZABLE', format_exc())) + except Exception as e: util.info('exception in thread serving %r', - threading.current_thread().name) + threading.current_thread().name) util.info(' ... message was %r', msg) - util.info(' ... exception was %r', exc) + util.info(' ... 
exception was %r', e) conn.close() sys.exit(1) @@ -312,10 +317,10 @@ def fallback_repr(self, conn, ident, obj): return repr(obj) fallback_mapping = { - '__str__': fallback_str, - '__repr__': fallback_repr, - '#GETVALUE': fallback_getvalue, - } + '__str__':fallback_str, + '__repr__':fallback_repr, + '#GETVALUE':fallback_getvalue + } def dummy(self, c): pass @@ -324,9 +329,10 @@ def debug_info(self, c): ''' Return some info --- useful to spot problems with refcounting ''' + # Perhaps include debug info about 'c'? with self.mutex: result = [] - keys = list(self.id_to_obj.keys()) + keys = list(self.id_to_refcount.keys()) keys.sort() for ident in keys: if ident != '0': @@ -339,14 +345,15 @@ def number_of_objects(self, c): ''' Number of shared objects ''' - return len(self.id_to_obj) - 1 # don't count ident='0' + # Doesn't use (len(self.id_to_obj) - 1) as we shouldn't count ident='0' + return len(self.id_to_refcount) def shutdown(self, c): ''' Shutdown this process ''' try: - util.debug('Manager received shutdown message') + util.debug('manager received shutdown message') c.send(('#RETURN', None)) except: import traceback @@ -354,16 +361,18 @@ def shutdown(self, c): finally: self.stop_event.set() - def create(self, c, typeid, *args, **kwds): + def create(self, c, typeid, /, *args, **kwds): ''' Create a new shared object and return its id ''' with self.mutex: callable, exposed, method_to_typeid, proxytype = \ - self.registry[typeid] + self.registry[typeid] if callable is None: - assert len(args) == 1 and not kwds + if kwds or (len(args) != 1): + raise ValueError( + "Without callable, must have one non-keyword argument") obj = args[0] else: obj = callable(*args, **kwds) @@ -371,23 +380,22 @@ def create(self, c, typeid, *args, **kwds): if exposed is None: exposed = public_methods(obj) if method_to_typeid is not None: - assert type(method_to_typeid) is dict + if not isinstance(method_to_typeid, dict): + raise TypeError( + "Method_to_typeid {0!r}: type {1!s}, not 
dict".format( + method_to_typeid, type(method_to_typeid))) exposed = list(exposed) + list(method_to_typeid) - # convert to string because xmlrpclib - # only has 32 bit signed integers - ident = '%x' % id(obj) + + ident = '%x' % id(obj) # convert to string because xmlrpclib + # only has 32 bit signed integers util.debug('%r callable returned object with id %r', typeid, ident) self.id_to_obj[ident] = (obj, set(exposed), method_to_typeid) if ident not in self.id_to_refcount: self.id_to_refcount[ident] = 0 - # increment the reference count immediately, to avoid - # this object being garbage collected before a Proxy - # object for it can be created. The caller of create() - # is responsible for doing a decref once the Proxy object - # has been created. - self.incref(c, ident) - return ident, tuple(exposed) + + self.incref(c, ident) + return ident, tuple(exposed) def get_methods(self, c, token): ''' @@ -405,21 +413,54 @@ def accept_connection(self, c, name): def incref(self, c, ident): with self.mutex: - self.id_to_refcount[ident] += 1 + try: + self.id_to_refcount[ident] += 1 + except KeyError as ke: + # If no external references exist but an internal (to the + # manager) still does and a new external reference is created + # from it, restore the manager's tracking of it from the + # previously stashed internal ref. 
+ if ident in self.id_to_local_proxy_obj: + self.id_to_refcount[ident] = 1 + self.id_to_obj[ident] = \ + self.id_to_local_proxy_obj[ident] + obj, exposed, gettypeid = self.id_to_obj[ident] + util.debug('Server re-enabled tracking & INCREF %r', ident) + else: + raise ke def decref(self, c, ident): + if ident not in self.id_to_refcount and \ + ident in self.id_to_local_proxy_obj: + util.debug('Server DECREF skipping %r', ident) + return + with self.mutex: - assert self.id_to_refcount[ident] >= 1 + if self.id_to_refcount[ident] <= 0: + raise AssertionError( + "Id {0!s} ({1!r}) has refcount {2:n}, not 1+".format( + ident, self.id_to_obj[ident], + self.id_to_refcount[ident])) self.id_to_refcount[ident] -= 1 if self.id_to_refcount[ident] == 0: - del self.id_to_obj[ident], self.id_to_refcount[ident] - util.debug('disposing of obj with id %r', ident) + del self.id_to_refcount[ident] + + if ident not in self.id_to_refcount: + # Two-step process in case the object turns out to contain other + # proxy objects (e.g. a managed list of managed lists). + # Otherwise, deleting self.id_to_obj[ident] would trigger the + # deleting of the stored value (another managed object) which would + # in turn attempt to acquire the mutex that is already held here. 
+ self.id_to_obj[ident] = (None, (), None) # thread-safe + util.debug('disposing of obj with id %r', ident) + with self.mutex: + del self.id_to_obj[ident] + # # Class to represent state of a manager # - class State(object): __slots__ = ['value'] INITIAL = 0 @@ -431,15 +472,14 @@ class State(object): # listener_client = { - 'pickle': (connection.Listener, connection.Client), - 'xmlrpclib': (connection.XmlListener, connection.XmlClient), -} + 'pickle' : (connection.Listener, connection.Client), + 'xmlrpclib' : (connection.XmlListener, connection.XmlClient) + } # # Definition of BaseManager # - class BaseManager(object): ''' Base class for managers @@ -459,15 +499,18 @@ def __init__(self, address=None, authkey=None, serializer='pickle', self._Listener, self._Client = listener_client[serializer] self._ctx = ctx or get_context() - def __reduce__(self): - return (type(self).from_address, - (self._address, self._authkey, self._serializer)) - def get_server(self): ''' Return server object with serve_forever() method and address attribute ''' - assert self._state.value == State.INITIAL + if self._state.value != State.INITIAL: + if self._state.value == State.STARTED: + raise ProcessError("Already started server") + elif self._state.value == State.SHUTDOWN: + raise ProcessError("Manager has shut down") + else: + raise ProcessError( + "Unknown state {!r}".format(self._state.value)) return Server(self._registry, self._address, self._authkey, self._serializer) @@ -484,7 +527,14 @@ def start(self, initializer=None, initargs=()): ''' Spawn a server process for this manager object ''' - assert self._state.value == State.INITIAL + if self._state.value != State.INITIAL: + if self._state.value == State.STARTED: + raise ProcessError("Already started server") + elif self._state.value == State.SHUTDOWN: + raise ProcessError("Manager has shut down") + else: + raise ProcessError( + "Unknown state {!r}".format(self._state.value)) if initializer is not None and not callable(initializer): 
raise TypeError('initializer must be a callable') @@ -497,9 +547,9 @@ def start(self, initializer=None, initargs=()): target=type(self)._run_server, args=(self._registry, self._address, self._authkey, self._serializer, writer, initializer, initargs), - ) + ) ident = ':'.join(str(i) for i in self._process._identity) - self._process.name = type(self).__name__ + '-' + ident + self._process.name = type(self).__name__ + '-' + ident self._process.start() # get address of server @@ -514,7 +564,7 @@ def start(self, initializer=None, initargs=()): args=(self._process, self._address, self._authkey, self._state, self._Client), exitpriority=0 - ) + ) @classmethod def _run_server(cls, registry, address, authkey, serializer, writer, @@ -522,6 +572,9 @@ def _run_server(cls, registry, address, authkey, serializer, writer, ''' Create a server, report its address and run it ''' + # bpo-36368: protect server process from KeyboardInterrupt signals + signal.signal(signal.SIGINT, signal.SIG_IGN) + if initializer is not None: initializer(*initargs) @@ -536,15 +589,14 @@ def _run_server(cls, registry, address, authkey, serializer, writer, util.info('manager serving at %r', server.address) server.serve_forever() - def _create(self, typeid, *args, **kwds): + def _create(self, typeid, /, *args, **kwds): ''' Create a new shared object; return the token and exposed tuple ''' assert self._state.value == State.STARTED, 'server not yet started' conn = self._Client(self._address, authkey=self._authkey) try: - id, exposed = dispatch(conn, None, 'create', - (typeid,) + args, kwds) + id, exposed = dispatch(conn, None, 'create', (typeid,)+args, kwds) finally: conn.close() return Token(typeid, self._address, id), exposed @@ -581,7 +633,14 @@ def _number_of_objects(self): def __enter__(self): if self._state.value == State.INITIAL: self.start() - assert self._state.value == State.STARTED + if self._state.value != State.STARTED: + if self._state.value == State.INITIAL: + raise ProcessError("Unable to 
start server") + elif self._state.value == State.SHUTDOWN: + raise ProcessError("Manager has shut down") + else: + raise ProcessError( + "Unknown state {!r}".format(self._state.value)) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -619,7 +678,9 @@ def _finalize_manager(process, address, authkey, state, _Client): except KeyError: pass - address = property(lambda self: self._address) + @property + def address(self): + return self._address @classmethod def register(cls, typeid, callable=None, proxytype=None, exposed=None, @@ -635,28 +696,26 @@ def register(cls, typeid, callable=None, proxytype=None, exposed=None, exposed = exposed or getattr(proxytype, '_exposed_', None) - method_to_typeid = ( - method_to_typeid or - getattr(proxytype, '_method_to_typeid_', None) - ) + method_to_typeid = method_to_typeid or \ + getattr(proxytype, '_method_to_typeid_', None) if method_to_typeid: - for key, value in items(method_to_typeid): + for key, value in list(method_to_typeid.items()): # isinstance? 
assert type(key) is str, '%r is not a string' % key assert type(value) is str, '%r is not a string' % value cls._registry[typeid] = ( callable, exposed, method_to_typeid, proxytype - ) + ) if create_method: - def temp(self, *args, **kwds): + def temp(self, /, *args, **kwds): util.debug('requesting creation of a shared %r object', typeid) token, exp = self._create(typeid, *args, **kwds) proxy = proxytype( token, self._serializer, manager=self, authkey=self._authkey, exposed=exp - ) + ) conn = self._Client(token.address, authkey=self._authkey) dispatch(conn, None, 'decref', (token.id,)) return proxy @@ -667,12 +726,9 @@ def temp(self, *args, **kwds): # Subclass of set which get cleared after a fork # - class ProcessLocalSet(set): - def __init__(self): util.register_after_fork(self, lambda obj: obj.clear()) - def __reduce__(self): return type(self), () @@ -680,7 +736,6 @@ def __reduce__(self): # Definition of BaseProxy # - class BaseProxy(object): ''' A base for proxies of shared objects @@ -689,7 +744,7 @@ class BaseProxy(object): _mutex = util.ForkAwareThreadLock() def __init__(self, token, serializer, manager=None, - authkey=None, exposed=None, incref=True): + authkey=None, exposed=None, incref=True, manager_owned=False): with BaseProxy._mutex: tls_idset = BaseProxy._address_to_local.get(token.address, None) if tls_idset is None: @@ -711,6 +766,12 @@ def __init__(self, token, serializer, manager=None, self._serializer = serializer self._Client = listener_client[serializer][1] + # Should be set to True only when a proxy object is being created + # on the manager server; primary use case: nested proxy objects. + # RebuildProxy detects when a proxy is being created on the manager + # and sets this value appropriately. 
+ self._owned_by_manager = manager_owned + if authkey is not None: self._authkey = process.AuthenticationString(authkey) elif self._manager is not None: @@ -734,7 +795,7 @@ def _connect(self): def _callmethod(self, methodname, args=(), kwds={}): ''' - Try to call a method of the referrent and return a copy of the result + Try to call a method of the referent and return a copy of the result ''' try: conn = self._tls.connection @@ -756,7 +817,7 @@ def _callmethod(self, methodname, args=(), kwds={}): proxy = proxytype( token, self._serializer, manager=self._manager, authkey=self._authkey, exposed=exposed - ) + ) conn = self._Client(token.address, authkey=self._authkey) dispatch(conn, None, 'decref', (token.id,)) return proxy @@ -769,6 +830,10 @@ def _getvalue(self): return self._callmethod('#GETVALUE') def _incref(self): + if self._owned_by_manager: + util.debug('owned_by_manager skipped INCREF of %r', self._token.id) + return + conn = self._Client(self._token.address, authkey=self._authkey) dispatch(conn, None, 'incref', (self._id,)) util.debug('INCREF %r', self._token.id) @@ -782,7 +847,7 @@ def _incref(self): args=(self._token, self._authkey, state, self._tls, self._idset, self._Client), exitpriority=10 - ) + ) @staticmethod def _decref(token, authkey, state, tls, idset, _Client): @@ -795,8 +860,8 @@ def _decref(token, authkey, state, tls, idset, _Client): util.debug('DECREF %r', token.id) conn = _Client(token.address, authkey=authkey) dispatch(conn, None, 'decref', (token.id,)) - except Exception as exc: - util.debug('... decref failed %s', exc) + except Exception as e: + util.debug('... 
decref failed %s', e) else: util.debug('DECREF %r -- manager already shutdown', token.id) @@ -813,13 +878,13 @@ def _after_fork(self): self._manager = None try: self._incref() - except Exception as exc: + except Exception as e: # the proxy may just be for a manager which has shutdown - util.info('incref failed: %s', exc) + util.info('incref failed: %s' % e) def __reduce__(self): kwds = {} - if context.get_spawning_popen() is not None: + if get_spawning_popen() is not None: kwds['authkey'] = self._authkey if getattr(self, '_isauto', False): @@ -850,32 +915,30 @@ def __str__(self): # Function used for unpickling # - def RebuildProxy(func, token, serializer, kwds): ''' Function used for unpickling proxy objects. - - If possible the shared object is returned, or otherwise a proxy for it. ''' server = getattr(process.current_process(), '_manager_server', None) - if server and server.address == token.address: - return server.id_to_obj[token.id][0] - else: - incref = ( - kwds.pop('incref', True) and - not getattr(process.current_process(), '_inheriting', False) + util.debug('Rebuild a proxy owned by manager, token=%r', token) + kwds['manager_owned'] = True + if token.id not in server.id_to_local_proxy_obj: + server.id_to_local_proxy_obj[token.id] = \ + server.id_to_obj[token.id] + incref = ( + kwds.pop('incref', True) and + not getattr(process.current_process(), '_inheriting', False) ) - return func(token, serializer, incref=incref, **kwds) + return func(token, serializer, incref=incref, **kwds) # # Functions to create proxies and proxy types # - def MakeProxyType(name, exposed, _cache={}): ''' - Return an proxy type whose methods are given by `exposed` + Return a proxy type whose methods are given by `exposed` ''' exposed = tuple(exposed) try: @@ -886,7 +949,7 @@ def MakeProxyType(name, exposed, _cache={}): dic = {} for meth in exposed: - exec('''def %s(self, *args, **kwds): + exec('''def %s(self, /, *args, **kwds): return self._callmethod(%r, args, kwds)''' % (meth, 
meth), dic) ProxyType = type(name, (BaseProxy,), dic) @@ -924,12 +987,9 @@ def AutoProxy(token, serializer, manager=None, authkey=None, # Types/callables which we will register with SyncManager # - class Namespace(object): - - def __init__(self, **kwds): + def __init__(self, /, **kwds): self.__dict__.update(kwds) - def __repr__(self): items = list(self.__dict__.items()) temp = [] @@ -939,25 +999,18 @@ def __repr__(self): temp.sort() return '%s(%s)' % (self.__class__.__name__, ', '.join(temp)) - class Value(object): - def __init__(self, typecode, value, lock=True): self._typecode = typecode self._value = value - def get(self): return self._value - def set(self, value): self._value = value - def __repr__(self): - return '%s(%r, %r)' % (type(self).__name__, - self._typecode, self._value) + return '%s(%r, %r)'%(type(self).__name__, self._typecode, self._value) value = property(get, set) - def Array(typecode, sequence, lock=True): return array.array(typecode, sequence) @@ -965,73 +1018,53 @@ def Array(typecode, sequence, lock=True): # Proxy types used by SyncManager # - class IteratorProxy(BaseProxy): - if sys.version_info[0] == 3: - _exposed = ('__next__', 'send', 'throw', 'close') - else: - _exposed_ = ('__next__', 'next', 'send', 'throw', 'close') - - def next(self, *args): - return self._callmethod('next', args) - + _exposed_ = ('__next__', 'send', 'throw', 'close') def __iter__(self): return self - def __next__(self, *args): return self._callmethod('__next__', args) - def send(self, *args): return self._callmethod('send', args) - def throw(self, *args): return self._callmethod('throw', args) - def close(self, *args): return self._callmethod('close', args) class AcquirerProxy(BaseProxy): _exposed_ = ('acquire', 'release') - def acquire(self, blocking=True, timeout=None): - args = (blocking, ) if timeout is None else (blocking, timeout) + args = (blocking,) if timeout is None else (blocking, timeout) return self._callmethod('acquire', args) - def release(self): 
return self._callmethod('release') - def __enter__(self): return self._callmethod('acquire') - def __exit__(self, exc_type, exc_val, exc_tb): return self._callmethod('release') class ConditionProxy(AcquirerProxy): _exposed_ = ('acquire', 'release', 'wait', 'notify', 'notify_all') - def wait(self, timeout=None): return self._callmethod('wait', (timeout,)) - - def notify(self): - return self._callmethod('notify') - + def notify(self, n=1): + return self._callmethod('notify', (n,)) def notify_all(self): return self._callmethod('notify_all') - def wait_for(self, predicate, timeout=None): result = predicate() if result: return result if timeout is not None: - endtime = monotonic() + timeout + endtime = time.monotonic() + timeout else: endtime = None waittime = None while not result: if endtime is not None: - waittime = endtime - monotonic() + waittime = endtime - time.monotonic() if waittime <= 0: break self.wait(waittime) @@ -1041,60 +1074,47 @@ def wait_for(self, predicate, timeout=None): class EventProxy(BaseProxy): _exposed_ = ('is_set', 'set', 'clear', 'wait') - def is_set(self): return self._callmethod('is_set') - def set(self): return self._callmethod('set') - def clear(self): return self._callmethod('clear') - def wait(self, timeout=None): return self._callmethod('wait', (timeout,)) class BarrierProxy(BaseProxy): _exposed_ = ('__getattribute__', 'wait', 'abort', 'reset') - def wait(self, timeout=None): - return self._callmethod('wait', (timeout, )) - + return self._callmethod('wait', (timeout,)) def abort(self): return self._callmethod('abort') - def reset(self): return self._callmethod('reset') - @property def parties(self): - return self._callmethod('__getattribute__', ('parties', )) - + return self._callmethod('__getattribute__', ('parties',)) @property def n_waiting(self): - return self._callmethod('__getattribute__', ('n_waiting', )) - + return self._callmethod('__getattribute__', ('n_waiting',)) @property def broken(self): - return 
self._callmethod('__getattribute__', ('broken', )) + return self._callmethod('__getattribute__', ('broken',)) class NamespaceProxy(BaseProxy): _exposed_ = ('__getattribute__', '__setattr__', '__delattr__') - def __getattr__(self, key): if key[0] == '_': return object.__getattribute__(self, key) callmethod = object.__getattribute__(self, '_callmethod') return callmethod('__getattribute__', (key,)) - def __setattr__(self, key, value): if key[0] == '_': return object.__setattr__(self, key, value) callmethod = object.__getattribute__(self, '_callmethod') return callmethod('__setattr__', (key, value)) - def __delattr__(self, key): if key[0] == '_': return object.__delattr__(self, key) @@ -1104,78 +1124,66 @@ def __delattr__(self, key): class ValueProxy(BaseProxy): _exposed_ = ('get', 'set') - def get(self): return self._callmethod('get') - def set(self, value): return self._callmethod('set', (value,)) value = property(get, set) + __class_getitem__ = classmethod(types.GenericAlias) + -_ListProxy_Attributes = ( +BaseListProxy = MakeProxyType('BaseListProxy', ( '__add__', '__contains__', '__delitem__', '__getitem__', '__len__', '__mul__', '__reversed__', '__rmul__', '__setitem__', 'append', 'count', 'extend', 'index', 'insert', 'pop', 'remove', - 'reverse', 'sort', '__imul__', -) -if not PY3: - _ListProxy_Attributes += ('__getslice__', '__setslice__', '__delslice__') -BaseListProxy = MakeProxyType('BaseListProxy', _ListProxy_Attributes) - - + 'reverse', 'sort', '__imul__' + )) class ListProxy(BaseListProxy): - def __iadd__(self, value): self._callmethod('extend', (value,)) return self - def __imul__(self, value): self._callmethod('__imul__', (value,)) return self DictProxy = MakeProxyType('DictProxy', ( - '__contains__', '__delitem__', '__getitem__', '__len__', - '__setitem__', 'clear', 'copy', 'get', 'has_key', 'items', - 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values', -)) + '__contains__', '__delitem__', '__getitem__', '__iter__', '__len__', + '__setitem__', 
'clear', 'copy', 'get', 'items', + 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values' + )) +DictProxy._method_to_typeid_ = { + '__iter__': 'Iterator', + } -_ArrayProxy_Attributes = ( - '__len__', '__getitem__', '__setitem__', -) -if not PY3: - _ArrayProxy_Attributes += ('__getslice__', '__setslice__') -ArrayProxy = MakeProxyType('ArrayProxy', _ArrayProxy_Attributes) +ArrayProxy = MakeProxyType('ArrayProxy', ( + '__len__', '__getitem__', '__setitem__' + )) BasePoolProxy = MakeProxyType('PoolProxy', ( 'apply', 'apply_async', 'close', 'imap', 'imap_unordered', 'join', 'map', 'map_async', 'starmap', 'starmap_async', 'terminate', -)) + )) BasePoolProxy._method_to_typeid_ = { 'apply_async': 'AsyncResult', 'map_async': 'AsyncResult', 'starmap_async': 'AsyncResult', 'imap': 'Iterator', - 'imap_unordered': 'Iterator', -} - - + 'imap_unordered': 'Iterator' + } class PoolProxy(BasePoolProxy): def __enter__(self): return self - - def __exit__(self, *exc_info): + def __exit__(self, exc_type, exc_val, exc_tb): self.terminate() - # # Definition of SyncManager # - class SyncManager(BaseManager): ''' Subclass of `BaseManager` which supports a number of shared object types. @@ -1183,12 +1191,12 @@ class SyncManager(BaseManager): The types registered are those intended for the synchronization of threads, plus `dict`, `list` and `Namespace`. - The `billiard.Manager()` function creates started instances of + The `multiprocessing.Manager()` function creates started instances of this class. 
''' -SyncManager.register('Queue', Queue) -SyncManager.register('JoinableQueue', Queue) +SyncManager.register('Queue', queue.Queue) +SyncManager.register('JoinableQueue', queue.Queue) SyncManager.register('Event', threading.Event, EventProxy) SyncManager.register('Lock', threading.Lock, AcquirerProxy) SyncManager.register('RLock', threading.RLock, AcquirerProxy) @@ -1196,8 +1204,7 @@ class SyncManager(BaseManager): SyncManager.register('BoundedSemaphore', threading.BoundedSemaphore, AcquirerProxy) SyncManager.register('Condition', threading.Condition, ConditionProxy) -if hasattr(threading, 'Barrier'): # PY3 - SyncManager.register('Barrier', threading.Barrier, BarrierProxy) +SyncManager.register('Barrier', threading.Barrier, BarrierProxy) SyncManager.register('Pool', pool.Pool, PoolProxy) SyncManager.register('list', list, ListProxy) SyncManager.register('dict', dict, DictProxy) @@ -1208,3 +1215,155 @@ class SyncManager(BaseManager): # types returned by methods of PoolProxy SyncManager.register('Iterator', proxytype=IteratorProxy, create_method=False) SyncManager.register('AsyncResult', create_method=False) + +# +# Definition of SharedMemoryManager and SharedMemoryServer +# + +if HAS_SHMEM: + class _SharedMemoryTracker: + "Manages one or more shared memory segments." + + def __init__(self, name, segment_names=[]): + self.shared_memory_context_name = name + self.segment_names = segment_names + + def register_segment(self, segment_name): + "Adds the supplied shared memory block name to tracker." 
+ util.debug(f"Register segment {segment_name!r} in pid {getpid()}") + self.segment_names.append(segment_name) + + def destroy_segment(self, segment_name): + """Calls unlink() on the shared memory block with the supplied name + and removes it from the list of blocks being tracked.""" + util.debug(f"Destroy segment {segment_name!r} in pid {getpid()}") + self.segment_names.remove(segment_name) + segment = shared_memory.SharedMemory(segment_name) + segment.close() + segment.unlink() + + def unlink(self): + "Calls destroy_segment() on all tracked shared memory blocks." + for segment_name in self.segment_names[:]: + self.destroy_segment(segment_name) + + def __del__(self): + util.debug(f"Call {self.__class__.__name__}.__del__ in {getpid()}") + self.unlink() + + def __getstate__(self): + return (self.shared_memory_context_name, self.segment_names) + + def __setstate__(self, state): + self.__init__(*state) + + + class SharedMemoryServer(Server): + + public = Server.public + \ + ['track_segment', 'release_segment', 'list_segments'] + + def __init__(self, *args, **kwargs): + Server.__init__(self, *args, **kwargs) + address = self.address + # The address of Linux abstract namespaces can be bytes + if isinstance(address, bytes): + address = os.fsdecode(address) + self.shared_memory_context = \ + _SharedMemoryTracker(f"shm_{address}_{getpid()}") + util.debug(f"SharedMemoryServer started by pid {getpid()}") + + def create(self, c, typeid, /, *args, **kwargs): + """Create a new distributed-shared object (not backed by a shared + memory block) and return its id to be used in a Proxy Object.""" + # Unless set up as a shared proxy, don't make shared_memory_context + # a standard part of kwargs. This makes things easier for supplying + # simple functions. 
+ if hasattr(self.registry[typeid][-1], "_shared_memory_proxy"): + kwargs['shared_memory_context'] = self.shared_memory_context + return Server.create(self, c, typeid, *args, **kwargs) + + def shutdown(self, c): + "Call unlink() on all tracked shared memory, terminate the Server." + self.shared_memory_context.unlink() + return Server.shutdown(self, c) + + def track_segment(self, c, segment_name): + "Adds the supplied shared memory block name to Server's tracker." + self.shared_memory_context.register_segment(segment_name) + + def release_segment(self, c, segment_name): + """Calls unlink() on the shared memory block with the supplied name + and removes it from the tracker instance inside the Server.""" + self.shared_memory_context.destroy_segment(segment_name) + + def list_segments(self, c): + """Returns a list of names of shared memory blocks that the Server + is currently tracking.""" + return self.shared_memory_context.segment_names + + + class SharedMemoryManager(BaseManager): + """Like SyncManager but uses SharedMemoryServer instead of Server. + + It provides methods for creating and returning SharedMemory instances + and for creating a list-like object (ShareableList) backed by shared + memory. It also provides methods that create and return Proxy Objects + that support synchronization across processes (i.e. multi-process-safe + locks and semaphores). + """ + + _Server = SharedMemoryServer + + def __init__(self, *args, **kwargs): + if os.name == "posix": + # bpo-36867: Ensure the resource_tracker is running before + # launching the manager process, so that concurrent + # shared_memory manipulation both in the manager and in the + # current process does not create two resource_tracker + # processes. + from . 
import resource_tracker + resource_tracker.ensure_running() + BaseManager.__init__(self, *args, **kwargs) + util.debug(f"{self.__class__.__name__} created by pid {getpid()}") + + def __del__(self): + util.debug(f"{self.__class__.__name__}.__del__ by pid {getpid()}") + pass + + def get_server(self): + 'Better than monkeypatching for now; merge into Server ultimately' + if self._state.value != State.INITIAL: + if self._state.value == State.STARTED: + raise ProcessError("Already started SharedMemoryServer") + elif self._state.value == State.SHUTDOWN: + raise ProcessError("SharedMemoryManager has shut down") + else: + raise ProcessError( + "Unknown state {!r}".format(self._state.value)) + return self._Server(self._registry, self._address, + self._authkey, self._serializer) + + def SharedMemory(self, size): + """Returns a new SharedMemory instance with the specified size in + bytes, to be tracked by the manager.""" + with self._Client(self._address, authkey=self._authkey) as conn: + sms = shared_memory.SharedMemory(None, create=True, size=size) + try: + dispatch(conn, None, 'track_segment', (sms.name,)) + except BaseException as e: + sms.unlink() + raise e + return sms + + def ShareableList(self, sequence): + """Returns a new ShareableList instance populated with the values + from the input sequence, to be tracked by the manager.""" + with self._Client(self._address, authkey=self._authkey) as conn: + sl = shared_memory.ShareableList(sequence) + try: + dispatch(conn, None, 'track_segment', (sl.shm.name,)) + except BaseException as e: + sl.shm.unlink() + raise e + return sl diff --git a/billiard/pool.py b/billiard/pool.py index 04c7f4a7..bbe05a55 100644 --- a/billiard/pool.py +++ b/billiard/pool.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Module providing the `Pool` class for managing a process pool # @@ -7,203 +6,77 @@ # Copyright (c) 2006-2008, R Oudkerk # Licensed to PSF under a Contributor Agreement. 
# -from __future__ import absolute_import + +__all__ = ['Pool', 'ThreadPool'] # # Imports # -import copy -import errno + +import collections import itertools import os -import platform -import signal -import sys +import queue import threading import time +import traceback +import types import warnings -from collections import deque -from functools import partial - -from . import cpu_count, get_context +# If threading is available then ThreadPool should be provided. Therefore +# we avoid top-level imports which are liable to fail on some systems. from . import util -from .common import ( - TERM_SIGNAL, human_status, pickle_loads, reset_signals, restart_state, -) -from .compat import get_errno, mem_rss, send_offset -from .einfo import ExceptionInfo -from .dummy import DummyProcess -from .exceptions import ( - CoroStop, - RestartFreqExceeded, - SoftTimeLimitExceeded, - Terminated, - TimeLimitExceeded, - TimeoutError, - WorkerLostError, -) -from .five import Empty, Queue, range, values, reraise, monotonic -from .util import Finalize, debug, warning - -MAXMEM_USED_FMT = """\ -child process exiting after exceeding memory limit ({0}KiB / {1}KiB) -""" - -PY3 = sys.version_info[0] == 3 - -if platform.system() == 'Windows': # pragma: no cover - # On Windows os.kill calls TerminateProcess which cannot be - # handled by # any process, so this is needed to terminate the task - # *and its children* (if any). - from ._win import kill_processtree as _kill # noqa - SIGKILL = TERM_SIGNAL -else: - from os import kill as _kill # noqa - SIGKILL = signal.SIGKILL - - -try: - TIMEOUT_MAX = threading.TIMEOUT_MAX -except AttributeError: # pragma: no cover - TIMEOUT_MAX = 1e10 # noqa - - -if sys.version_info >= (3, 3): - _Semaphore = threading.Semaphore -else: - # Semaphore is a factory function pointing to _Semaphore - _Semaphore = threading._Semaphore # noqa +from . 
import get_context, TimeoutError +from .connection import wait # # Constants representing the state of a pool # -RUN = 0 -CLOSE = 1 -TERMINATE = 2 - -# -# Constants representing the state of a job -# - -ACK = 0 -READY = 1 -TASK = 2 -NACK = 3 -DEATH = 4 - -# -# Exit code constants -# -EX_OK = 0 -EX_FAILURE = 1 -EX_RECYCLE = 0x9B - - -# Signal used for soft time limits. -SIG_SOFT_TIMEOUT = getattr(signal, "SIGUSR1", None) +INIT = "INIT" +RUN = "RUN" +CLOSE = "CLOSE" +TERMINATE = "TERMINATE" # # Miscellaneous # -LOST_WORKER_TIMEOUT = 10.0 -EX_OK = getattr(os, "EX_OK", 0) -GUARANTEE_MESSAGE_CONSUMPTION_RETRY_LIMIT = 300 -GUARANTEE_MESSAGE_CONSUMPTION_RETRY_INTERVAL = 0.1 - job_counter = itertools.count() -Lock = threading.Lock - - -def _get_send_offset(connection): - try: - native = connection.send_offset - except AttributeError: - native = None - if native is None: - return partial(send_offset, connection.fileno()) - return native - - def mapstar(args): return list(map(*args)) - def starmapstar(args): return list(itertools.starmap(args[0], args[1])) +# +# Hack to embed stringification of remote traceback in local traceback +# -def error(msg, *args, **kwargs): - util.get_logger().error(msg, *args, **kwargs) - - -def stop_if_not_current(thread, timeout=None): - if thread is not threading.current_thread(): - thread.stop(timeout) - - -class LaxBoundedSemaphore(_Semaphore): - """Semaphore that checks that # release is <= # acquires, - but ignores if # releases >= value.""" - - def shrink(self): - self._initial_value -= 1 - self.acquire() - - if PY3: - - def __init__(self, value=1, verbose=None): - _Semaphore.__init__(self, value) - self._initial_value = value - - def grow(self): - with self._cond: - self._initial_value += 1 - self._value += 1 - self._cond.notify() +class RemoteTraceback(Exception): + def __init__(self, tb): + self.tb = tb + def __str__(self): + return self.tb + +class ExceptionWithTraceback: + def __init__(self, exc, tb): + tb = 
traceback.format_exception(type(exc), exc, tb) + tb = ''.join(tb) + self.exc = exc + self.tb = '\n"""\n%s"""' % tb + def __reduce__(self): + return rebuild_exc, (self.exc, self.tb) - def release(self): - cond = self._cond - with cond: - if self._value < self._initial_value: - self._value += 1 - cond.notify_all() - - def clear(self): - while self._value < self._initial_value: - _Semaphore.release(self) - else: - - def __init__(self, value=1, verbose=None): - _Semaphore.__init__(self, value, verbose) - self._initial_value = value - - def grow(self): - cond = self._Semaphore__cond - with cond: - self._initial_value += 1 - self._Semaphore__value += 1 - cond.notify() - - def release(self): # noqa - cond = self._Semaphore__cond - with cond: - if self._Semaphore__value < self._initial_value: - self._Semaphore__value += 1 - cond.notifyAll() - - def clear(self): # noqa - while self._Semaphore__value < self._initial_value: - _Semaphore.release(self) +def rebuild_exc(exc, tb): + exc.__cause__ = RemoteTraceback(tb) + return exc # -# Exceptions +# Code run by worker processes # - class MaybeEncodingError(Exception): """Wraps possible unpickleable errors, so they can be safely sent through the socket.""" @@ -213,1152 +86,258 @@ def __init__(self, exc, value): self.value = repr(value) super(MaybeEncodingError, self).__init__(self.exc, self.value) - def __repr__(self): - return "<%s: %s>" % (self.__class__.__name__, str(self)) - def __str__(self): - return "Error sending result: '%r'. Reason: '%r'." % ( - self.value, self.exc) - - -class WorkersJoined(Exception): - """All workers have terminated.""" - - -def soft_timeout_sighandler(signum, frame): - raise SoftTimeLimitExceeded() - -# -# Code run by worker processes -# - + return "Error sending result: '%s'. 
Reason: '%s'" % (self.value, + self.exc) -class Worker(object): - - def __init__(self, inq, outq, synq=None, initializer=None, initargs=(), - maxtasks=None, sentinel=None, on_exit=None, - sigprotection=True, wrap_exception=True, - max_memory_per_child=None, on_ready_counter=None): - assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0) - self.initializer = initializer - self.initargs = initargs - self.maxtasks = maxtasks - self.max_memory_per_child = max_memory_per_child - self._shutdown = sentinel - self.on_exit = on_exit - self.sigprotection = sigprotection - self.inq, self.outq, self.synq = inq, outq, synq - self.wrap_exception = wrap_exception # XXX cannot disable yet - self.on_ready_counter = on_ready_counter - self.contribute_to_object(self) - - def contribute_to_object(self, obj): - obj.inq, obj.outq, obj.synq = self.inq, self.outq, self.synq - obj.inqW_fd = self.inq._writer.fileno() # inqueue write fd - obj.outqR_fd = self.outq._reader.fileno() # outqueue read fd - if self.synq: - obj.synqR_fd = self.synq._reader.fileno() # synqueue read fd - obj.synqW_fd = self.synq._writer.fileno() # synqueue write fd - obj.send_syn_offset = _get_send_offset(self.synq._writer) - else: - obj.synqR_fd = obj.synqW_fd = obj._send_syn_offset = None - obj._quick_put = self.inq._writer.send - obj._quick_get = self.outq._reader.recv - obj.send_job_offset = _get_send_offset(self.inq._writer) - return obj - - def __reduce__(self): - return self.__class__, ( - self.inq, self.outq, self.synq, self.initializer, - self.initargs, self.maxtasks, self._shutdown, self.on_exit, - self.sigprotection, self.wrap_exception, self.max_memory_per_child, - ) + def __repr__(self): + return "<%s: %s>" % (self.__class__.__name__, self) - def __call__(self): - _exit = sys.exit - _exitcode = [None] - def exit(status=None): - _exitcode[0] = status - return _exit(status) - sys.exit = exit +def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, + wrap_exception=False): + 
if (maxtasks is not None) and not (isinstance(maxtasks, int) + and maxtasks >= 1): + raise AssertionError("Maxtasks {!r} is not valid".format(maxtasks)) + put = outqueue.put + get = inqueue.get + if hasattr(inqueue, '_writer'): + inqueue._writer.close() + outqueue._reader.close() - pid = os.getpid() + if initializer is not None: + initializer(*initargs) - self._make_child_methods() - self.after_fork() - self.on_loop_start(pid=pid) # callback on loop start + completed = 0 + while maxtasks is None or (maxtasks and completed < maxtasks): try: - sys.exit(self.workloop(pid=pid)) - except Exception as exc: - error('Pool process %r error: %r', self, exc, exc_info=1) - self._do_exit(pid, _exitcode[0], exc) - finally: - self._do_exit(pid, _exitcode[0], None) - - def _do_exit(self, pid, exitcode, exc=None): - if exitcode is None: - exitcode = EX_FAILURE if exc else EX_OK + task = get() + except (EOFError, OSError): + util.debug('worker got EOFError or OSError -- exiting') + break - if self.on_exit is not None: - self.on_exit(pid, exitcode) + if task is None: + util.debug('worker got sentinel -- exiting') + break - if sys.platform != 'win32': - try: - self.outq.put((DEATH, (pid, exitcode))) - time.sleep(1) - finally: - os._exit(exitcode) - else: - os._exit(exitcode) - - def on_loop_start(self, pid): - pass - - def prepare_result(self, result): - return result - - def workloop(self, debug=debug, now=monotonic, pid=None): - pid = pid or os.getpid() - put = self.outq.put - inqW_fd = self.inqW_fd - synqW_fd = self.synqW_fd - maxtasks = self.maxtasks - max_memory_per_child = self.max_memory_per_child or 0 - prepare_result = self.prepare_result - - wait_for_job = self.wait_for_job - _wait_for_syn = self.wait_for_syn - - def wait_for_syn(jid): - i = 0 - while 1: - if i > 60: - error('!!!WAIT FOR ACK TIMEOUT: job:%r fd:%r!!!', - jid, self.synq._reader.fileno(), exc_info=1) - req = _wait_for_syn() - if req: - type_, args = req - if type_ == NACK: - return False - assert type_ == ACK - 
return True - i += 1 - - completed = 0 + job, i, func, args, kwds = task try: - while maxtasks is None or (maxtasks and completed < maxtasks): - req = wait_for_job() - if req: - type_, args_ = req - assert type_ == TASK - job, i, fun, args, kwargs = args_ - put((ACK, (job, i, now(), pid, synqW_fd))) - if _wait_for_syn: - confirm = wait_for_syn(job) - if not confirm: - continue # received NACK - try: - result = (True, prepare_result(fun(*args, **kwargs))) - except Exception: - result = (False, ExceptionInfo()) - try: - put((READY, (job, i, result, inqW_fd))) - except Exception as exc: - _, _, tb = sys.exc_info() - try: - wrapped = MaybeEncodingError(exc, result[1]) - einfo = ExceptionInfo(( - MaybeEncodingError, wrapped, tb, - )) - put((READY, (job, i, (False, einfo), inqW_fd))) - finally: - del(tb) - completed += 1 - if max_memory_per_child > 0: - used_kb = mem_rss() - if used_kb <= 0: - error('worker unable to determine memory usage') - if used_kb > 0 and used_kb > max_memory_per_child: - warning(MAXMEM_USED_FMT.format( - used_kb, max_memory_per_child)) - return EX_RECYCLE - - debug('worker exiting after %d tasks', completed) - if maxtasks: - return EX_RECYCLE if completed == maxtasks else EX_FAILURE - return EX_OK - finally: - # Before exiting the worker, we want to ensure that that all - # messages produced by the worker have been consumed by the main - # process. This prevents the worker being terminated prematurely - # and messages being lost. 
- self._ensure_messages_consumed(completed=completed) - - def _ensure_messages_consumed(self, completed): - """ Returns true if all messages sent out have been received and - consumed within a reasonable amount of time """ - - if not self.on_ready_counter: - return False - - for retry in range(GUARANTEE_MESSAGE_CONSUMPTION_RETRY_LIMIT): - if self.on_ready_counter.value >= completed: - debug('ensured messages consumed after %d retries', retry) - return True - time.sleep(GUARANTEE_MESSAGE_CONSUMPTION_RETRY_INTERVAL) - warning('could not ensure all messages were consumed prior to ' - 'exiting') - return False - - def after_fork(self): - if hasattr(self.inq, '_writer'): - self.inq._writer.close() - if hasattr(self.outq, '_reader'): - self.outq._reader.close() - - if self.initializer is not None: - self.initializer(*self.initargs) - - # Make sure all exiting signals call finally: blocks. - # This is important for the semaphore to be released. - reset_signals(full=self.sigprotection) - - # install signal handler for soft timeouts. 
- if SIG_SOFT_TIMEOUT is not None: - signal.signal(SIG_SOFT_TIMEOUT, soft_timeout_sighandler) - + result = (True, func(*args, **kwds)) + except Exception as e: + if wrap_exception and func is not _helper_reraises_exception: + e = ExceptionWithTraceback(e, e.__traceback__) + result = (False, e) try: - signal.signal(signal.SIGINT, signal.SIG_IGN) - except AttributeError: - pass - - def _make_recv_method(self, conn): - get = conn.get + put((job, i, result)) + except Exception as e: + wrapped = MaybeEncodingError(e, result[1]) + util.debug("Possible encoding error while sending result: %s" % ( + wrapped)) + put((job, i, (False, wrapped))) - if hasattr(conn, '_reader'): - _poll = conn._reader.poll - if hasattr(conn, 'get_payload') and conn.get_payload: - get_payload = conn.get_payload - - def _recv(timeout, loads=pickle_loads): - return True, loads(get_payload()) - else: - def _recv(timeout): # noqa - if _poll(timeout): - return True, get() - return False, None - else: - def _recv(timeout): # noqa - try: - return True, get(timeout=timeout) - except Queue.Empty: - return False, None - return _recv - - def _make_child_methods(self, loads=pickle_loads): - self.wait_for_job = self._make_protected_receive(self.inq) - self.wait_for_syn = (self._make_protected_receive(self.synq) - if self.synq else None) - - def _make_protected_receive(self, conn): - _receive = self._make_recv_method(conn) - should_shutdown = self._shutdown.is_set if self._shutdown else None - - def receive(debug=debug): - if should_shutdown and should_shutdown(): - debug('worker got sentinel -- exiting') - raise SystemExit(EX_OK) - try: - ready, req = _receive(1.0) - if not ready: - return None - except (EOFError, IOError) as exc: - if get_errno(exc) == errno.EINTR: - return None # interrupted, maybe by gdb - debug('worker got %s -- exiting', type(exc).__name__) - raise SystemExit(EX_FAILURE) - if req is None: - debug('worker got sentinel -- exiting') - raise SystemExit(EX_FAILURE) - return req - - return 
receive + task = job = result = func = args = kwds = None + completed += 1 + util.debug('worker exiting after %d tasks' % completed) +def _helper_reraises_exception(ex): + 'Pickle-able helper function for use by _guarded_task_generation.' + raise ex # # Class representing a process pool # - -class PoolThread(DummyProcess): - - def __init__(self, *args, **kwargs): - DummyProcess.__init__(self) - self._state = RUN - self._was_started = False - self.daemon = True - - def run(self): - try: - return self.body() - except RestartFreqExceeded as exc: - error("Thread %r crashed: %r", type(self).__name__, exc, - exc_info=1) - _kill(os.getpid(), TERM_SIGNAL) - sys.exit() - except Exception as exc: - error("Thread %r crashed: %r", type(self).__name__, exc, - exc_info=1) - os._exit(1) - - def start(self, *args, **kwargs): - self._was_started = True - super(PoolThread, self).start(*args, **kwargs) - - def on_stop_not_started(self): - pass - - def stop(self, timeout=None): - if self._was_started: - self.join(timeout) - return - self.on_stop_not_started() - - def terminate(self): - self._state = TERMINATE - - def close(self): - self._state = CLOSE - - -class Supervisor(PoolThread): - - def __init__(self, pool): - self.pool = pool - super(Supervisor, self).__init__() - - def body(self): - debug('worker handler starting') - - time.sleep(0.8) - - pool = self.pool - - try: - # do a burst at startup to verify that we can start - # our pool processes, and in that time we lower - # the max restart frequency. 
- prev_state = pool.restart_state - pool.restart_state = restart_state(10 * pool._processes, 1) - for _ in range(10): - if self._state == RUN and pool._state == RUN: - pool._maintain_pool() - time.sleep(0.1) - - # Keep maintaing workers until the cache gets drained, unless - # the pool is termianted - pool.restart_state = prev_state - while self._state == RUN and pool._state == RUN: - pool._maintain_pool() - time.sleep(0.8) - except RestartFreqExceeded: - pool.close() - pool.join() - raise - debug('worker handler exiting') - - -class TaskHandler(PoolThread): - - def __init__(self, taskqueue, put, outqueue, pool, cache): - self.taskqueue = taskqueue - self.put = put - self.outqueue = outqueue - self.pool = pool - self.cache = cache - super(TaskHandler, self).__init__() - - def body(self): - cache = self.cache - taskqueue = self.taskqueue - put = self.put - - for taskseq, set_length in iter(taskqueue.get, None): - task = None - i = -1 - try: - for i, task in enumerate(taskseq): - if self._state: - debug('task handler found thread._state != RUN') - break - try: - put(task) - except IOError: - debug('could not put task on queue') - break - except Exception: - job, ind = task[:2] - try: - cache[job]._set(ind, (False, ExceptionInfo())) - except KeyError: - pass - else: - if set_length: - debug('doing set_length()') - set_length(i + 1) - continue - break - except Exception: - job, ind = task[:2] if task else (0, 0) - if job in cache: - cache[job]._set(ind + 1, (False, ExceptionInfo())) - if set_length: - util.debug('doing set_length()') - set_length(i + 1) - else: - debug('task handler got sentinel') - - self.tell_others() - - def tell_others(self): - outqueue = self.outqueue - put = self.put - pool = self.pool - - try: - # tell result handler to finish when cache is empty - debug('task handler sending sentinel to result handler') - outqueue.put(None) - - # tell workers there is no more work - debug('task handler sending sentinel to workers') - for p in pool: - put(None) 
- except IOError: - debug('task handler got IOError when sending sentinels') - - debug('task handler exiting') - - def on_stop_not_started(self): - self.tell_others() - - -class TimeoutHandler(PoolThread): - - def __init__(self, processes, cache, t_soft, t_hard): - self.processes = processes - self.cache = cache - self.t_soft = t_soft - self.t_hard = t_hard - self._it = None - super(TimeoutHandler, self).__init__() - - def _process_by_pid(self, pid): - return next(( - (proc, i) for i, proc in enumerate(self.processes) - if proc.pid == pid - ), (None, None)) - - def on_soft_timeout(self, job): - debug('soft time limit exceeded for %r', job) - process, _index = self._process_by_pid(job._worker_pid) - if not process: - return - - # Run timeout callback - job.handle_timeout(soft=True) - - try: - _kill(job._worker_pid, SIG_SOFT_TIMEOUT) - except OSError as exc: - if get_errno(exc) != errno.ESRCH: - raise - - def on_hard_timeout(self, job): - if job.ready(): - return - debug('hard time limit exceeded for %r', job) - # Remove from cache and set return value to an exception - try: - raise TimeLimitExceeded(job._timeout) - except TimeLimitExceeded: - job._set(job._job, (False, ExceptionInfo())) - else: # pragma: no cover - pass - - # Remove from _pool - process, _index = self._process_by_pid(job._worker_pid) - - # Run timeout callback - job.handle_timeout(soft=False) - - if process: - self._trywaitkill(process) - - def _trywaitkill(self, worker): - debug('timeout: sending TERM to %s', worker._name) - try: - if os.getpgid(worker.pid) == worker.pid: - debug("worker %s is a group leader. It is safe to kill (SIGTERM) the whole group", worker.pid) - os.killpg(os.getpgid(worker.pid), signal.SIGTERM) - else: - worker.terminate() - except OSError: - pass - else: - if worker._popen.wait(timeout=0.1): - return - debug('timeout: TERM timed-out, now sending KILL to %s', worker._name) - try: - if os.getpgid(worker.pid) == worker.pid: - debug("worker %s is a group leader. 
It is safe to kill (SIGKILL) the whole group", worker.pid) - os.killpg(os.getpgid(worker.pid), signal.SIGKILL) - else: - _kill(worker.pid, SIGKILL) - except OSError: - pass - - def handle_timeouts(self): - t_hard, t_soft = self.t_hard, self.t_soft - dirty = set() - on_soft_timeout = self.on_soft_timeout - on_hard_timeout = self.on_hard_timeout - - def _timed_out(start, timeout): - if not start or not timeout: - return False - if monotonic() >= start + timeout: - return True - - # Inner-loop - while self._state == RUN: - # Perform a shallow copy before iteration because keys can change. - # A deep copy fails (on shutdown) due to thread.lock objects. - # https://github.com/celery/billiard/issues/260 - cache = copy.copy(self.cache) - - # Remove dirty items not in cache anymore - if dirty: - dirty = set(k for k in dirty if k in cache) - - for i, job in cache.items(): - ack_time = job._time_accepted - soft_timeout = job._soft_timeout - if soft_timeout is None: - soft_timeout = t_soft - hard_timeout = job._timeout - if hard_timeout is None: - hard_timeout = t_hard - if _timed_out(ack_time, hard_timeout): - on_hard_timeout(job) - elif i not in dirty and _timed_out(ack_time, soft_timeout): - on_soft_timeout(job) - dirty.add(i) - yield - - def body(self): - while self._state == RUN: - try: - for _ in self.handle_timeouts(): - time.sleep(1.0) # don't spin - except CoroStop: - break - debug('timeout handler exiting') - - def handle_event(self, *args): - if self._it is None: - self._it = self.handle_timeouts() - try: - next(self._it) - except StopIteration: - self._it = None - - -class ResultHandler(PoolThread): - - def __init__(self, outqueue, get, cache, poll, - join_exited_workers, putlock, restart_state, - check_timeouts, on_job_ready, on_ready_counters=None): - self.outqueue = outqueue - self.get = get - self.cache = cache - self.poll = poll - self.join_exited_workers = join_exited_workers - self.putlock = putlock - self.restart_state = restart_state - self._it = None - 
self._shutdown_complete = False - self.check_timeouts = check_timeouts - self.on_job_ready = on_job_ready - self.on_ready_counters = on_ready_counters - self._make_methods() - super(ResultHandler, self).__init__() - - def on_stop_not_started(self): - # used when pool started without result handler thread. - self.finish_at_shutdown(handle_timeouts=True) - - def _make_methods(self): - cache = self.cache - putlock = self.putlock - restart_state = self.restart_state - on_job_ready = self.on_job_ready - - def on_ack(job, i, time_accepted, pid, synqW_fd): - restart_state.R = 0 - try: - cache[job]._ack(i, time_accepted, pid, synqW_fd) - except (KeyError, AttributeError): - # Object gone or doesn't support _ack (e.g. IMAPIterator). - pass - - def on_ready(job, i, obj, inqW_fd): - if on_job_ready is not None: - on_job_ready(job, i, obj, inqW_fd) - try: - item = cache[job] - except KeyError: - return - - if self.on_ready_counters: - worker_pid = next(iter(item.worker_pids()), None) - if worker_pid and worker_pid in self.on_ready_counters: - on_ready_counter = self.on_ready_counters[worker_pid] - with on_ready_counter.get_lock(): - on_ready_counter.value += 1 - - if not item.ready(): - if putlock is not None: - putlock.release() - try: - item._set(i, obj) - except KeyError: - pass - - def on_death(pid, exitcode): - try: - os.kill(pid, TERM_SIGNAL) - except OSError as exc: - if get_errno(exc) != errno.ESRCH: - raise - - state_handlers = self.state_handlers = { - ACK: on_ack, READY: on_ready, DEATH: on_death - } - - def on_state_change(task): - state, args = task - try: - state_handlers[state](*args) - except KeyError: - debug("Unknown job state: %s (args=%s)", state, args) - self.on_state_change = on_state_change - - def _process_result(self, timeout=1.0): - poll = self.poll - on_state_change = self.on_state_change - - while 1: - try: - ready, task = poll(timeout) - except (IOError, EOFError) as exc: - debug('result handler got %r -- exiting', exc) - raise CoroStop() - - if 
self._state: - assert self._state == TERMINATE - debug('result handler found thread._state=TERMINATE') - raise CoroStop() - - if ready: - if task is None: - debug('result handler got sentinel') - raise CoroStop() - on_state_change(task) - if timeout != 0: # blocking - break - else: - break - yield - - def handle_event(self, fileno=None, events=None): - if self._state == RUN: - if self._it is None: - self._it = self._process_result(0) # non-blocking - try: - next(self._it) - except (StopIteration, CoroStop): - self._it = None - - def body(self): - debug('result handler starting') - try: - while self._state == RUN: - try: - for _ in self._process_result(1.0): # blocking - pass - except CoroStop: - break - finally: - self.finish_at_shutdown() - - def finish_at_shutdown(self, handle_timeouts=False): - self._shutdown_complete = True - get = self.get - outqueue = self.outqueue - cache = self.cache - poll = self.poll - join_exited_workers = self.join_exited_workers - check_timeouts = self.check_timeouts - on_state_change = self.on_state_change - - time_terminate = None - while cache and self._state != TERMINATE: - if check_timeouts is not None: - check_timeouts() - try: - ready, task = poll(1.0) - except (IOError, EOFError) as exc: - debug('result handler got %r -- exiting', exc) - return - - if ready: - if task is None: - debug('result handler ignoring extra sentinel') - continue - - on_state_change(task) - try: - join_exited_workers(shutdown=True) - except WorkersJoined: - now = monotonic() - if not time_terminate: - time_terminate = now - else: - if now - time_terminate > 5.0: - debug('result handler exiting: timed out') - break - debug('result handler: all workers terminated, ' - 'timeout in %ss', - abs(min(now - time_terminate - 5.0, 0))) - - if hasattr(outqueue, '_reader'): - debug('ensuring that outqueue is not full') - # If we don't make room available in outqueue then - # attempts to add the sentinel (None) to outqueue may - # block. 
There is guaranteed to be no more than 2 sentinels. - try: - for i in range(10): - if not outqueue._reader.poll(): - break - get() - except (IOError, EOFError): - pass - - debug('result handler exiting: len(cache)=%s, thread._state=%s', - len(cache), self._state) - +class _PoolCache(dict): + """ + Class that implements a cache for the Pool class that will notify + the pool management threads every time the cache is emptied. The + notification is done by the use of a queue that is provided when + instantiating the cache. + """ + def __init__(self, /, *args, notifier=None, **kwds): + self.notifier = notifier + super().__init__(*args, **kwds) + + def __delitem__(self, item): + super().__delitem__(item) + + # Notify that the cache is empty. This is important because the + # pool keeps maintaining workers until the cache gets drained. This + # eliminates a race condition in which a task is finished after the + # the pool's _handle_workers method has enter another iteration of the + # loop. In this situation, the only event that can wake up the pool + # is the cache to be emptied (no more tasks available). + if not self: + self.notifier.put(None) class Pool(object): ''' Class which supports an async version of applying functions to arguments. 
''' _wrap_exception = True - Worker = Worker - Supervisor = Supervisor - TaskHandler = TaskHandler - TimeoutHandler = TimeoutHandler - ResultHandler = ResultHandler - SoftTimeLimitExceeded = SoftTimeLimitExceeded + + @staticmethod + def Process(ctx, *args, **kwds): + return ctx.Process(*args, **kwds) def __init__(self, processes=None, initializer=None, initargs=(), - maxtasksperchild=None, timeout=None, soft_timeout=None, - lost_worker_timeout=None, - max_restarts=None, max_restart_freq=1, - on_process_up=None, - on_process_down=None, - on_timeout_set=None, - on_timeout_cancel=None, - threads=True, - semaphore=None, - putlocks=False, - allow_restart=False, - synack=False, - on_process_exit=None, - context=None, - max_memory_per_child=None, - enable_timeouts=False, - **kwargs): + maxtasksperchild=None, context=None): + # Attributes initialized early to make sure that they exist in + # __del__() if __init__() raises an exception + self._pool = [] + self._state = INIT + self._ctx = context or get_context() - self.synack = synack self._setup_queues() - self._taskqueue = Queue() - self._cache = {} - self._state = RUN - self.timeout = timeout - self.soft_timeout = soft_timeout + self._taskqueue = queue.SimpleQueue() + # The _change_notifier queue exist to wake up self._handle_workers() + # when the cache (self._cache) is empty or when there is a change in + # the _state variable of the thread that runs _handle_workers. 
+ self._change_notifier = self._ctx.SimpleQueue() + self._cache = _PoolCache(notifier=self._change_notifier) self._maxtasksperchild = maxtasksperchild - self._max_memory_per_child = max_memory_per_child self._initializer = initializer self._initargs = initargs - self._on_process_exit = on_process_exit - self.lost_worker_timeout = lost_worker_timeout or LOST_WORKER_TIMEOUT - self.on_process_up = on_process_up - self.on_process_down = on_process_down - self.on_timeout_set = on_timeout_set - self.on_timeout_cancel = on_timeout_cancel - self.threads = threads - self.readers = {} - self.allow_restart = allow_restart - - self.enable_timeouts = bool( - enable_timeouts or - self.timeout is not None or - self.soft_timeout is not None - ) - - if soft_timeout and SIG_SOFT_TIMEOUT is None: - warnings.warn(UserWarning( - "Soft timeouts are not supported: " - "on this platform: It does not have the SIGUSR1 signal.", - )) - soft_timeout = None - self._processes = self.cpu_count() if processes is None else processes - self.max_restarts = max_restarts or round(self._processes * 100) - self.restart_state = restart_state(max_restarts, max_restart_freq or 1) + if processes is None: + processes = os.cpu_count() or 1 + if processes < 1: + raise ValueError("Number of processes must be at least 1") if initializer is not None and not callable(initializer): raise TypeError('initializer must be a callable') - if on_process_exit is not None and not callable(on_process_exit): - raise TypeError('on_process_exit must be callable') + self._processes = processes + try: + self._repopulate_pool() + except Exception: + for p in self._pool: + if p.exitcode is None: + p.terminate() + for p in self._pool: + p.join() + raise - self._Process = self._ctx.Process + sentinels = self._get_sentinels() - self._pool = [] - self._poolctrl = {} - self._on_ready_counters = {} - self.putlocks = putlocks - self._putlock = semaphore or LaxBoundedSemaphore(self._processes) - for i in range(self._processes): - 
self._create_worker_process(i) - - self._worker_handler = self.Supervisor(self) - if threads: - self._worker_handler.start() - - self._task_handler = self.TaskHandler(self._taskqueue, - self._quick_put, - self._outqueue, - self._pool, - self._cache) - if threads: - self._task_handler.start() - - self.check_timeouts = None - - # Thread killing timedout jobs. - if self.enable_timeouts: - self._timeout_handler = self.TimeoutHandler( - self._pool, self._cache, - self.soft_timeout, self.timeout, + self._worker_handler = threading.Thread( + target=Pool._handle_workers, + args=(self._cache, self._taskqueue, self._ctx, self.Process, + self._processes, self._pool, self._inqueue, self._outqueue, + self._initializer, self._initargs, self._maxtasksperchild, + self._wrap_exception, sentinels, self._change_notifier) ) - self._timeout_handler_mutex = Lock() - self._timeout_handler_started = False - self._start_timeout_handler() - # If running without threads, we need to check for timeouts - # while waiting for unfinished work at shutdown. - if not threads: - self.check_timeouts = self._timeout_handler.handle_event - else: - self._timeout_handler = None - self._timeout_handler_started = False - self._timeout_handler_mutex = None + self._worker_handler.daemon = True + self._worker_handler._state = RUN + self._worker_handler.start() - # Thread processing results in the outqueue. 
- self._result_handler = self.create_result_handler() - self.handle_result_event = self._result_handler.handle_event - if threads: - self._result_handler.start() + self._task_handler = threading.Thread( + target=Pool._handle_tasks, + args=(self._taskqueue, self._quick_put, self._outqueue, + self._pool, self._cache) + ) + self._task_handler.daemon = True + self._task_handler._state = RUN + self._task_handler.start() - self._terminate = Finalize( - self, self._terminate_pool, - args=(self._taskqueue, self._inqueue, self._outqueue, - self._pool, self._worker_handler, self._task_handler, - self._result_handler, self._cache, - self._timeout_handler, - self._help_stuff_finish_args()), - exitpriority=15, - ) + self._result_handler = threading.Thread( + target=Pool._handle_results, + args=(self._outqueue, self._quick_get, self._cache) + ) + self._result_handler.daemon = True + self._result_handler._state = RUN + self._result_handler.start() - def Process(self, *args, **kwds): - return self._Process(*args, **kwds) + self._terminate = util.Finalize( + self, self._terminate_pool, + args=(self._taskqueue, self._inqueue, self._outqueue, self._pool, + self._change_notifier, self._worker_handler, self._task_handler, + self._result_handler, self._cache), + exitpriority=15 + ) + self._state = RUN - def WorkerProcess(self, worker): - return worker.contribute_to_object(self.Process(target=worker)) + # Copy globals as function locals to make sure that they are available + # during Python shutdown when the Pool is destroyed. 
+ def __del__(self, _warn=warnings.warn, RUN=RUN): + if self._state == RUN: + _warn(f"unclosed running multiprocessing pool {self!r}", + ResourceWarning, source=self) + if getattr(self, '_change_notifier', None) is not None: + self._change_notifier.put(None) - def create_result_handler(self, **extra_kwargs): - return self.ResultHandler( - self._outqueue, self._quick_get, self._cache, - self._poll_result, self._join_exited_workers, - self._putlock, self.restart_state, self.check_timeouts, - self.on_job_ready, on_ready_counters=self._on_ready_counters, - **extra_kwargs - ) + def __repr__(self): + cls = self.__class__ + return (f'<{cls.__module__}.{cls.__qualname__} ' + f'state={self._state} ' + f'pool_size={len(self._pool)}>') - def on_job_ready(self, job, i, obj, inqW_fd): - pass + def _get_sentinels(self): + task_queue_sentinels = [self._outqueue._reader] + self_notifier_sentinels = [self._change_notifier._reader] + return [*task_queue_sentinels, *self_notifier_sentinels] - def _help_stuff_finish_args(self): - return self._inqueue, self._task_handler, self._pool + @staticmethod + def _get_worker_sentinels(workers): + return [worker.sentinel for worker in + workers if hasattr(worker, "sentinel")] - def cpu_count(self): - try: - return cpu_count() - except NotImplementedError: - return 1 - - def handle_result_event(self, *args): - return self._result_handler.handle_event(*args) - - def _process_register_queues(self, worker, queues): - pass - - def _process_by_pid(self, pid): - return next(( - (proc, i) for i, proc in enumerate(self._pool) - if proc.pid == pid - ), (None, None)) - - def get_process_queues(self): - return self._inqueue, self._outqueue, None - - def _create_worker_process(self, i): - sentinel = self._ctx.Event() if self.allow_restart else None - inq, outq, synq = self.get_process_queues() - on_ready_counter = self._ctx.Value('i') - w = self.WorkerProcess(self.Worker( - inq, outq, synq, self._initializer, self._initargs, - self._maxtasksperchild, 
sentinel, self._on_process_exit, - # Need to handle all signals if using the ipc semaphore, - # to make sure the semaphore is released. - sigprotection=self.threads, - wrap_exception=self._wrap_exception, - max_memory_per_child=self._max_memory_per_child, - on_ready_counter=on_ready_counter, - )) - self._pool.append(w) - self._process_register_queues(w, (inq, outq, synq)) - w.name = w.name.replace('Process', 'PoolWorker') - w.daemon = True - w.index = i - w.start() - self._poolctrl[w.pid] = sentinel - self._on_ready_counters[w.pid] = on_ready_counter - if self.on_process_up: - self.on_process_up(w) - return w - - def process_flush_queues(self, worker): - pass - - def _join_exited_workers(self, shutdown=False): - """Cleanup after any worker processes which have exited due to - reaching their specified lifetime. Returns True if any workers were - cleaned up. + @staticmethod + def _join_exited_workers(pool): + """Cleanup after any worker processes which have exited due to reaching + their specified lifetime. Returns True if any workers were cleaned up. """ - now = None - # The worker may have published a result before being terminated, - # but we have no way to accurately tell if it did. So we wait for - # _lost_worker_timeout seconds before we mark the job with - # WorkerLostError. 
- for job in [job for job in list(self._cache.values()) - if not job.ready() and job._worker_lost]: - now = now or monotonic() - lost_time, lost_ret = job._worker_lost - if now - lost_time > job._lost_worker_timeout: - self.mark_as_worker_lost(job, lost_ret) - - if shutdown and not len(self._pool): - raise WorkersJoined() - - cleaned, exitcodes = {}, {} - for i in reversed(range(len(self._pool))): - worker = self._pool[i] - exitcode = worker.exitcode - popen = worker._popen - if popen is None or exitcode is not None: + cleaned = False + for i in reversed(range(len(pool))): + worker = pool[i] + if worker.exitcode is not None: # worker exited - debug('Supervisor: cleaning up worker %d', i) - if popen is not None: - worker.join() - debug('Supervisor: worked %d joined', i) - cleaned[worker.pid] = worker - exitcodes[worker.pid] = exitcode - if exitcode not in (EX_OK, EX_RECYCLE) and \ - not getattr(worker, '_controlled_termination', False): - error( - 'Process %r pid:%r exited with %r', - worker.name, worker.pid, human_status(exitcode), - exc_info=0, - ) - self.process_flush_queues(worker) - del self._pool[i] - del self._poolctrl[worker.pid] - del self._on_ready_counters[worker.pid] - if cleaned: - all_pids = [w.pid for w in self._pool] - for job in list(self._cache.values()): - acked_by_gone = next( - (pid for pid in job.worker_pids() - if pid in cleaned or pid not in all_pids), - None - ) - # already accepted by process - if acked_by_gone: - self.on_job_process_down(job, acked_by_gone) - if not job.ready(): - exitcode = exitcodes.get(acked_by_gone) or 0 - proc = cleaned.get(acked_by_gone) - if proc and getattr(proc, '_job_terminated', False): - job._set_terminated(exitcode) - else: - self.on_job_process_lost( - job, acked_by_gone, exitcode, - ) - else: - # started writing to - write_to = job._write_to - # was scheduled to write to - sched_for = job._scheduled_for - - if write_to and not write_to._is_alive(): - self.on_job_process_down(job, write_to.pid) - elif 
sched_for and not sched_for._is_alive(): - self.on_job_process_down(job, sched_for.pid) - - for worker in values(cleaned): - if self.on_process_down: - if not shutdown: - self._process_cleanup_queues(worker) - self.on_process_down(worker) - return list(exitcodes.values()) - return [] - - def on_partial_read(self, job, worker): - pass - - def _process_cleanup_queues(self, worker): - pass - - def on_job_process_down(self, job, pid_gone): - pass + util.debug('cleaning up worker %d' % i) + worker.join() + cleaned = True + del pool[i] + return cleaned + + def _repopulate_pool(self): + return self._repopulate_pool_static(self._ctx, self.Process, + self._processes, + self._pool, self._inqueue, + self._outqueue, self._initializer, + self._initargs, + self._maxtasksperchild, + self._wrap_exception) - def on_job_process_lost(self, job, pid, exitcode): - job._worker_lost = (monotonic(), exitcode) - - def mark_as_worker_lost(self, job, exitcode): - try: - raise WorkerLostError( - 'Worker exited prematurely: {0}.'.format( - human_status(exitcode)), - ) - except WorkerLostError: - job._set(None, (False, ExceptionInfo())) - else: # pragma: no cover - pass - - def __enter__(self): - return self - - def __exit__(self, *exc_info): - return self.terminate() - - def on_grow(self, n): - pass - - def on_shrink(self, n): - pass - - def shrink(self, n=1): - for i, worker in enumerate(self._iterinactive()): - self._processes -= 1 - if self._putlock: - self._putlock.shrink() - worker.terminate_controlled() - self.on_shrink(1) - if i >= n - 1: - break - else: - raise ValueError("Can't shrink pool. 
All processes busy!") - - def grow(self, n=1): - for i in range(n): - self._processes += 1 - if self._putlock: - self._putlock.grow() - self.on_grow(n) - - def _iterinactive(self): - for worker in self._pool: - if not self._worker_active(worker): - yield worker - - def _worker_active(self, worker): - for job in values(self._cache): - if worker.pid in job.worker_pids(): - return True - return False - - def _repopulate_pool(self, exitcodes): + @staticmethod + def _repopulate_pool_static(ctx, Process, processes, pool, inqueue, + outqueue, initializer, initargs, + maxtasksperchild, wrap_exception): """Bring the number of pool processes up to the specified number, for use after reaping workers which have exited. """ - for i in range(self._processes - len(self._pool)): - if self._state != RUN: - return - try: - if exitcodes and exitcodes[i] not in (EX_OK, EX_RECYCLE): - self.restart_state.step() - except IndexError: - self.restart_state.step() - self._create_worker_process(self._avail_index()) - debug('added worker') - - def _avail_index(self): - assert len(self._pool) < self._processes - indices = set(p.index for p in self._pool) - return next(i for i in range(self._processes) if i not in indices) - - def did_start_ok(self): - return not self._join_exited_workers() + for i in range(processes - len(pool)): + w = Process(ctx, target=worker, + args=(inqueue, outqueue, + initializer, + initargs, maxtasksperchild, + wrap_exception)) + w.name = w.name.replace('Process', 'PoolWorker') + w.daemon = True + w.start() + pool.append(w) + util.debug('added worker') - def _maintain_pool(self): - """"Clean up any exited workers and start replacements for them. + @staticmethod + def _maintain_pool(ctx, Process, processes, pool, inqueue, outqueue, + initializer, initargs, maxtasksperchild, + wrap_exception): + """Clean up any exited workers and start replacements for them. 
""" - joined = self._join_exited_workers() - self._repopulate_pool(joined) - for i in range(len(joined)): - if self._putlock is not None: - self._putlock.release() - - def maintain_pool(self): - if self._worker_handler._state == RUN and self._state == RUN: - try: - self._maintain_pool() - except RestartFreqExceeded: - self.close() - self.join() - raise - except OSError as exc: - if get_errno(exc) == errno.ENOMEM: - reraise(MemoryError, - MemoryError(str(exc)), - sys.exc_info()[2]) - raise + if Pool._join_exited_workers(pool): + Pool._repopulate_pool_static(ctx, Process, processes, pool, + inqueue, outqueue, initializer, + initargs, maxtasksperchild, + wrap_exception) def _setup_queues(self): self._inqueue = self._ctx.SimpleQueue() @@ -1366,27 +345,23 @@ def _setup_queues(self): self._quick_put = self._inqueue._writer.send self._quick_get = self._outqueue._reader.recv - def _poll_result(timeout): - if self._outqueue._reader.poll(timeout): - return True, self._quick_get() - return False, None - self._poll_result = _poll_result - - def _start_timeout_handler(self): - # ensure more than one thread does not start the timeout handler - # thread at once. - if self.threads and self._timeout_handler is not None: - with self._timeout_handler_mutex: - if not self._timeout_handler_started: - self._timeout_handler_started = True - self._timeout_handler.start() + def _check_running(self): + if self._state != RUN: + raise ValueError("Pool not running") def apply(self, func, args=(), kwds={}): ''' - Equivalent of `func(*args, **kwargs)`. + Equivalent of `func(*args, **kwds)`. + Pool must be running. ''' - if self._state == RUN: - return self.apply_async(func, args, kwds).get() + return self.apply_async(func, args, kwds).get() + + def map(self, func, iterable, chunksize=None): + ''' + Apply `func` to each element in `iterable`, collecting the results + in a list that is returned. 
+ ''' + return self._map_async(func, iterable, mapstar, chunksize).get() def starmap(self, func, iterable, chunksize=None): ''' @@ -1394,173 +369,108 @@ def starmap(self, func, iterable, chunksize=None): be iterables as well and will be unpacked as arguments. Hence `func` and (a, b) becomes func(a, b). ''' - if self._state == RUN: - return self._map_async(func, iterable, - starmapstar, chunksize).get() + return self._map_async(func, iterable, starmapstar, chunksize).get() - def starmap_async(self, func, iterable, chunksize=None, - callback=None, error_callback=None): + def starmap_async(self, func, iterable, chunksize=None, callback=None, + error_callback=None): ''' Asynchronous version of `starmap()` method. ''' - if self._state == RUN: - return self._map_async(func, iterable, starmapstar, chunksize, - callback, error_callback) + return self._map_async(func, iterable, starmapstar, chunksize, + callback, error_callback) - def map(self, func, iterable, chunksize=None): - ''' - Apply `func` to each element in `iterable`, collecting the results - in a list that is returned. - ''' - if self._state == RUN: - return self.map_async(func, iterable, chunksize).get() + def _guarded_task_generation(self, result_job, func, iterable): + '''Provides a generator of tasks for imap and imap_unordered with + appropriate handling for iterables which throw exceptions during + iteration.''' + try: + i = -1 + for i, x in enumerate(iterable): + yield (result_job, i, func, (x,), {}) + except Exception as e: + yield (result_job, i+1, _helper_reraises_exception, (e,), {}) - def imap(self, func, iterable, chunksize=1, lost_worker_timeout=None): + def imap(self, func, iterable, chunksize=1): ''' Equivalent of `map()` -- can be MUCH slower than `Pool.map()`. 
''' - if self._state != RUN: - return - lost_worker_timeout = lost_worker_timeout or self.lost_worker_timeout + self._check_running() if chunksize == 1: - result = IMapIterator(self._cache, - lost_worker_timeout=lost_worker_timeout) - self._taskqueue.put(( - ((TASK, (result._job, i, func, (x,), {})) - for i, x in enumerate(iterable)), - result._set_length, - )) + result = IMapIterator(self) + self._taskqueue.put( + ( + self._guarded_task_generation(result._job, func, iterable), + result._set_length + )) return result else: - assert chunksize > 1 + if chunksize < 1: + raise ValueError( + "Chunksize must be 1+, not {0:n}".format( + chunksize)) task_batches = Pool._get_tasks(func, iterable, chunksize) - result = IMapIterator(self._cache, - lost_worker_timeout=lost_worker_timeout) - self._taskqueue.put(( - ((TASK, (result._job, i, mapstar, (x,), {})) - for i, x in enumerate(task_batches)), - result._set_length, - )) + result = IMapIterator(self) + self._taskqueue.put( + ( + self._guarded_task_generation(result._job, + mapstar, + task_batches), + result._set_length + )) return (item for chunk in result for item in chunk) - def imap_unordered(self, func, iterable, chunksize=1, - lost_worker_timeout=None): + def imap_unordered(self, func, iterable, chunksize=1): ''' Like `imap()` method but ordering of results is arbitrary. 
''' - if self._state != RUN: - return - lost_worker_timeout = lost_worker_timeout or self.lost_worker_timeout + self._check_running() if chunksize == 1: - result = IMapUnorderedIterator( - self._cache, lost_worker_timeout=lost_worker_timeout, - ) - self._taskqueue.put(( - ((TASK, (result._job, i, func, (x,), {})) - for i, x in enumerate(iterable)), - result._set_length, - )) + result = IMapUnorderedIterator(self) + self._taskqueue.put( + ( + self._guarded_task_generation(result._job, func, iterable), + result._set_length + )) return result else: - assert chunksize > 1 + if chunksize < 1: + raise ValueError( + "Chunksize must be 1+, not {0!r}".format(chunksize)) task_batches = Pool._get_tasks(func, iterable, chunksize) - result = IMapUnorderedIterator( - self._cache, lost_worker_timeout=lost_worker_timeout, - ) - self._taskqueue.put(( - ((TASK, (result._job, i, mapstar, (x,), {})) - for i, x in enumerate(task_batches)), - result._set_length, - )) + result = IMapUnorderedIterator(self) + self._taskqueue.put( + ( + self._guarded_task_generation(result._job, + mapstar, + task_batches), + result._set_length + )) return (item for chunk in result for item in chunk) - def apply_async(self, func, args=(), kwds={}, - callback=None, error_callback=None, accept_callback=None, - timeout_callback=None, waitforslot=None, - soft_timeout=None, timeout=None, lost_worker_timeout=None, - callbacks_propagate=(), - correlation_id=None): + def apply_async(self, func, args=(), kwds={}, callback=None, + error_callback=None): ''' - Asynchronous equivalent of `apply()` method. - - Callback is called when the functions return value is ready. - The accept callback is called when the job is accepted to be executed. - - Simplified the flow is like this: - - >>> def apply_async(func, args, kwds, callback, accept_callback): - ... if accept_callback: - ... accept_callback() - ... retval = func(*args, **kwds) - ... if callback: - ... callback(retval) - + Asynchronous version of `apply()` method. 
''' - if self._state != RUN: - return - soft_timeout = soft_timeout or self.soft_timeout - timeout = timeout or self.timeout - lost_worker_timeout = lost_worker_timeout or self.lost_worker_timeout - if soft_timeout and SIG_SOFT_TIMEOUT is None: - warnings.warn(UserWarning( - "Soft timeouts are not supported: " - "on this platform: It does not have the SIGUSR1 signal.", - )) - soft_timeout = None - if self._state == RUN: - waitforslot = self.putlocks if waitforslot is None else waitforslot - if waitforslot and self._putlock is not None: - self._putlock.acquire() - result = ApplyResult( - self._cache, callback, accept_callback, timeout_callback, - error_callback, soft_timeout, timeout, lost_worker_timeout, - on_timeout_set=self.on_timeout_set, - on_timeout_cancel=self.on_timeout_cancel, - callbacks_propagate=callbacks_propagate, - send_ack=self.send_ack if self.synack else None, - correlation_id=correlation_id, - ) - if timeout or soft_timeout: - # start the timeout handler thread when required. - self._start_timeout_handler() - if self.threads: - self._taskqueue.put(([(TASK, (result._job, None, - func, args, kwds))], None)) - else: - self._quick_put((TASK, (result._job, None, func, args, kwds))) - return result - - def send_ack(self, response, job, i, fd): - pass - - def terminate_job(self, pid, sig=None): - proc, _ = self._process_by_pid(pid) - if proc is not None: - try: - _kill(pid, sig or TERM_SIGNAL) - except OSError as exc: - if get_errno(exc) != errno.ESRCH: - raise - else: - proc._controlled_termination = True - proc._job_terminated = True + self._check_running() + result = ApplyResult(self, callback, error_callback) + self._taskqueue.put(([(result._job, 0, func, args, kwds)], None)) + return result - def map_async(self, func, iterable, chunksize=None, - callback=None, error_callback=None): + def map_async(self, func, iterable, chunksize=None, callback=None, + error_callback=None): ''' - Asynchronous equivalent of `map()` method. 
+ Asynchronous version of `map()` method. ''' - return self._map_async( - func, iterable, mapstar, chunksize, callback, error_callback, - ) + return self._map_async(func, iterable, mapstar, chunksize, callback, + error_callback) - def _map_async(self, func, iterable, mapper, chunksize=None, - callback=None, error_callback=None): + def _map_async(self, func, iterable, mapper, chunksize=None, callback=None, + error_callback=None): ''' Helper function to implement map, starmap and their async counterparts. ''' - if self._state != RUN: - return + self._check_running() if not hasattr(iterable, '__len__'): iterable = list(iterable) @@ -1572,12 +482,151 @@ def _map_async(self, func, iterable, mapper, chunksize=None, chunksize = 0 task_batches = Pool._get_tasks(func, iterable, chunksize) - result = MapResult(self._cache, chunksize, len(iterable), callback, + result = MapResult(self, chunksize, len(iterable), callback, error_callback=error_callback) - self._taskqueue.put((((TASK, (result._job, i, mapper, (x,), {})) - for i, x in enumerate(task_batches)), None)) + self._taskqueue.put( + ( + self._guarded_task_generation(result._job, + mapper, + task_batches), + None + ) + ) return result + @staticmethod + def _wait_for_updates(sentinels, change_notifier, timeout=None): + wait(sentinels, timeout=timeout) + while not change_notifier.empty(): + change_notifier.get() + + @classmethod + def _handle_workers(cls, cache, taskqueue, ctx, Process, processes, + pool, inqueue, outqueue, initializer, initargs, + maxtasksperchild, wrap_exception, sentinels, + change_notifier): + thread = threading.current_thread() + + # Keep maintaining workers until the cache gets drained, unless the pool + # is terminated. 
+ while thread._state == RUN or (cache and thread._state != TERMINATE): + cls._maintain_pool(ctx, Process, processes, pool, inqueue, + outqueue, initializer, initargs, + maxtasksperchild, wrap_exception) + + current_sentinels = [*cls._get_worker_sentinels(pool), *sentinels] + + cls._wait_for_updates(current_sentinels, change_notifier) + # send sentinel to stop workers + taskqueue.put(None) + util.debug('worker handler exiting') + + @staticmethod + def _handle_tasks(taskqueue, put, outqueue, pool, cache): + thread = threading.current_thread() + + for taskseq, set_length in iter(taskqueue.get, None): + task = None + try: + # iterating taskseq cannot fail + for task in taskseq: + if thread._state != RUN: + util.debug('task handler found thread._state != RUN') + break + try: + put(task) + except Exception as e: + job, idx = task[:2] + try: + cache[job]._set(idx, (False, e)) + except KeyError: + pass + else: + if set_length: + util.debug('doing set_length()') + idx = task[1] if task else -1 + set_length(idx + 1) + continue + break + finally: + task = taskseq = job = None + else: + util.debug('task handler got sentinel') + + try: + # tell result handler to finish when cache is empty + util.debug('task handler sending sentinel to result handler') + outqueue.put(None) + + # tell workers there is no more work + util.debug('task handler sending sentinel to workers') + for p in pool: + put(None) + except OSError: + util.debug('task handler got OSError when sending sentinels') + + util.debug('task handler exiting') + + @staticmethod + def _handle_results(outqueue, get, cache): + thread = threading.current_thread() + + while 1: + try: + task = get() + except (OSError, EOFError): + util.debug('result handler got EOFError/OSError -- exiting') + return + + if thread._state != RUN: + assert thread._state == TERMINATE, "Thread not in TERMINATE" + util.debug('result handler found thread._state=TERMINATE') + break + + if task is None: + util.debug('result handler got sentinel') + 
break + + job, i, obj = task + try: + cache[job]._set(i, obj) + except KeyError: + pass + task = job = obj = None + + while cache and thread._state != TERMINATE: + try: + task = get() + except (OSError, EOFError): + util.debug('result handler got EOFError/OSError -- exiting') + return + + if task is None: + util.debug('result handler ignoring extra sentinel') + continue + job, i, obj = task + try: + cache[job]._set(i, obj) + except KeyError: + pass + task = job = obj = None + + if hasattr(outqueue, '_reader'): + util.debug('ensuring that outqueue is not full') + # If we don't make room available in outqueue then + # attempts to add the sentinel (None) to outqueue may + # block. There is guaranteed to be no more than 2 sentinels. + try: + for i in range(10): + if not outqueue._reader.poll(): + break + get() + except (OSError, EOFError): + pass + + util.debug('result handler exiting: len(cache)=%s, thread._state=%s', + len(cache), thread._state) + @staticmethod def _get_tasks(func, it, size): it = iter(it) @@ -1589,190 +638,126 @@ def _get_tasks(func, it, size): def __reduce__(self): raise NotImplementedError( - 'pool objects cannot be passed between processes or pickled', - ) + 'pool objects cannot be passed between processes or pickled' + ) def close(self): - debug('closing pool') + util.debug('closing pool') if self._state == RUN: self._state = CLOSE - if self._putlock: - self._putlock.clear() - self._worker_handler.close() - self._taskqueue.put(None) - stop_if_not_current(self._worker_handler) + self._worker_handler._state = CLOSE + self._change_notifier.put(None) def terminate(self): - debug('terminating pool') + util.debug('terminating pool') self._state = TERMINATE - self._worker_handler.terminate() self._terminate() - @staticmethod - def _stop_task_handler(task_handler): - stop_if_not_current(task_handler) - def join(self): - assert self._state in (CLOSE, TERMINATE) - debug('joining worker handler') - stop_if_not_current(self._worker_handler) - debug('joining 
task handler') - self._stop_task_handler(self._task_handler) - debug('joining result handler') - stop_if_not_current(self._result_handler) - debug('result handler joined') - for i, p in enumerate(self._pool): - debug('joining worker %s/%s (%r)', i + 1, len(self._pool), p) - if p._popen is not None: # process started? - p.join() - debug('pool join complete') - - def restart(self): - for e in values(self._poolctrl): - e.set() + util.debug('joining pool') + if self._state == RUN: + raise ValueError("Pool is still running") + elif self._state not in (CLOSE, TERMINATE): + raise ValueError("In unknown state") + self._worker_handler.join() + self._task_handler.join() + self._result_handler.join() + for p in self._pool: + p.join() @staticmethod - def _help_stuff_finish(inqueue, task_handler, _pool): + def _help_stuff_finish(inqueue, task_handler, size): # task_handler may be blocked trying to put items on inqueue - debug('removing tasks from inqueue until task handler finished') + util.debug('removing tasks from inqueue until task handler finished') inqueue._rlock.acquire() while task_handler.is_alive() and inqueue._reader.poll(): inqueue._reader.recv() time.sleep(0) @classmethod - def _set_result_sentinel(cls, outqueue, pool): - outqueue.put(None) - - @classmethod - def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, - worker_handler, task_handler, - result_handler, cache, timeout_handler, - help_stuff_finish_args): - + def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, change_notifier, + worker_handler, task_handler, result_handler, cache): # this is guaranteed to only be called once - debug('finalizing pool') + util.debug('finalizing pool') + + # Notify that the worker_handler state has been changed so the + # _handle_workers loop can be unblocked (and exited) in order to + # send the finalization sentinel all the workers. 
+ worker_handler._state = TERMINATE + change_notifier.put(None) - worker_handler.terminate() + task_handler._state = TERMINATE - task_handler.terminate() - taskqueue.put(None) # sentinel + util.debug('helping task handler/workers to finish') + cls._help_stuff_finish(inqueue, task_handler, len(pool)) - debug('helping task handler/workers to finish') - cls._help_stuff_finish(*help_stuff_finish_args) + if (not result_handler.is_alive()) and (len(cache) != 0): + raise AssertionError( + "Cannot have cache with result_hander not alive") - result_handler.terminate() - cls._set_result_sentinel(outqueue, pool) + result_handler._state = TERMINATE + change_notifier.put(None) + outqueue.put(None) # sentinel - if timeout_handler is not None: - timeout_handler.terminate() + # We must wait for the worker handler to exit before terminating + # workers because we don't want workers to be restarted behind our back. + util.debug('joining worker handler') + if threading.current_thread() is not worker_handler: + worker_handler.join() - # Terminate workers which haven't already finished + # Terminate workers which haven't already finished. 
if pool and hasattr(pool[0], 'terminate'): - debug('terminating workers') + util.debug('terminating workers') for p in pool: - if p._is_alive(): + if p.exitcode is None: p.terminate() - debug('joining task handler') - cls._stop_task_handler(task_handler) + util.debug('joining task handler') + if threading.current_thread() is not task_handler: + task_handler.join() - debug('joining result handler') - result_handler.stop() - - if timeout_handler is not None: - debug('joining timeout handler') - timeout_handler.stop(TIMEOUT_MAX) + util.debug('joining result handler') + if threading.current_thread() is not result_handler: + result_handler.join() if pool and hasattr(pool[0], 'terminate'): - debug('joining pool workers') + util.debug('joining pool workers') for p in pool: if p.is_alive(): # worker has not yet exited - debug('cleaning up worker %d', p.pid) - if p._popen is not None: - p.join() - debug('pool workers joined') + util.debug('cleaning up worker %d' % p.pid) + p.join() + + def __enter__(self): + self._check_running() + return self - @property - def process_sentinels(self): - return [w._popen.sentinel for w in self._pool] + def __exit__(self, exc_type, exc_val, exc_tb): + self.terminate() # # Class whose instances are returned by `Pool.apply_async()` # - class ApplyResult(object): - _worker_lost = None - _write_to = None - _scheduled_for = None - - def __init__(self, cache, callback, accept_callback=None, - timeout_callback=None, error_callback=None, soft_timeout=None, - timeout=None, lost_worker_timeout=LOST_WORKER_TIMEOUT, - on_timeout_set=None, on_timeout_cancel=None, - callbacks_propagate=(), send_ack=None, - correlation_id=None): - self.correlation_id = correlation_id - self._mutex = Lock() + + def __init__(self, pool, callback, error_callback): + self._pool = pool self._event = threading.Event() self._job = next(job_counter) - self._cache = cache + self._cache = pool._cache self._callback = callback - self._accept_callback = accept_callback 
self._error_callback = error_callback - self._timeout_callback = timeout_callback - self._timeout = timeout - self._soft_timeout = soft_timeout - self._lost_worker_timeout = lost_worker_timeout - self._on_timeout_set = on_timeout_set - self._on_timeout_cancel = on_timeout_cancel - self._callbacks_propagate = callbacks_propagate or () - self._send_ack = send_ack - - self._accepted = False - self._cancelled = False - self._worker_pid = None - self._time_accepted = None - self._terminated = None - cache[self._job] = self - - def __repr__(self): - return '<%s: {id} ack:{ack} ready:{ready}>'.format( - self.__class__.__name__, - id=self._job, ack=self._accepted, ready=self.ready(), - ) + self._cache[self._job] = self def ready(self): - return self._event.isSet() - - def accepted(self): - return self._accepted + return self._event.is_set() def successful(self): - assert self.ready() + if not self.ready(): + raise ValueError("{0!r} not ready".format(self)) return self._success - def _cancel(self): - """Only works if synack is used.""" - self._cancelled = True - - def discard(self): - self._cache.pop(self._job, None) - - def terminate(self, signum): - self._terminated = signum - - def _set_terminated(self, signum=None): - try: - raise Terminated(-(signum or 0)) - except Terminated: - self._set(None, (False, ExceptionInfo())) - - def worker_pids(self): - return [self._worker_pid] if self._worker_pid else [] - def wait(self, timeout=None): self._event.wait(timeout) @@ -1783,160 +768,81 @@ def get(self, timeout=None): if self._success: return self._value else: - raise self._value.exception - - def safe_apply_callback(self, fun, *args, **kwargs): - if fun: - try: - fun(*args, **kwargs) - except self._callbacks_propagate: - raise - except Exception as exc: - error('Pool callback raised exception: %r', exc, - exc_info=1) - - def handle_timeout(self, soft=False): - if self._timeout_callback is not None: - self.safe_apply_callback( - self._timeout_callback, soft=soft, - 
timeout=self._soft_timeout if soft else self._timeout, - ) + raise self._value def _set(self, i, obj): - with self._mutex: - if self._on_timeout_cancel: - self._on_timeout_cancel(self) - self._success, self._value = obj - self._event.set() - if self._accepted: - # if not accepted yet, then the set message - # was received before the ack, which means - # the ack will remove the entry. - self._cache.pop(self._job, None) - - # apply callbacks last - if self._callback and self._success: - self.safe_apply_callback( - self._callback, self._value) - if (self._value is not None and - self._error_callback and not self._success): - self.safe_apply_callback( - self._error_callback, self._value) - - def _ack(self, i, time_accepted, pid, synqW_fd): - with self._mutex: - if self._cancelled and self._send_ack: - self._accepted = True - if synqW_fd: - return self._send_ack(NACK, pid, self._job, synqW_fd) - return - self._accepted = True - self._time_accepted = time_accepted - self._worker_pid = pid - if self.ready(): - # ack received after set() - self._cache.pop(self._job, None) - if self._on_timeout_set: - self._on_timeout_set(self, self._soft_timeout, self._timeout) - response = ACK - if self._accept_callback: - try: - self._accept_callback(pid, time_accepted) - except self._propagate_errors: - response = NACK - raise - except Exception: - response = NACK - # ignore other errors - finally: - if self._send_ack and synqW_fd: - return self._send_ack( - response, pid, self._job, synqW_fd - ) - if self._send_ack and synqW_fd: - self._send_ack(response, pid, self._job, synqW_fd) + self._success, self._value = obj + if self._callback and self._success: + self._callback(self._value) + if self._error_callback and not self._success: + self._error_callback(self._value) + self._event.set() + del self._cache[self._job] + self._pool = None + + __class_getitem__ = classmethod(types.GenericAlias) + +AsyncResult = ApplyResult # create alias -- see #17805 # # Class whose instances are returned 
by `Pool.map_async()` # - class MapResult(ApplyResult): - def __init__(self, cache, chunksize, length, callback, error_callback): - ApplyResult.__init__( - self, cache, callback, error_callback=error_callback, - ) + def __init__(self, pool, chunksize, length, callback, error_callback): + ApplyResult.__init__(self, pool, callback, + error_callback=error_callback) self._success = True - self._length = length self._value = [None] * length - self._accepted = [False] * length - self._worker_pid = [None] * length - self._time_accepted = [None] * length self._chunksize = chunksize if chunksize <= 0: self._number_left = 0 self._event.set() - del cache[self._job] + del self._cache[self._job] else: - self._number_left = length // chunksize + bool(length % chunksize) + self._number_left = length//chunksize + bool(length % chunksize) def _set(self, i, success_result): + self._number_left -= 1 success, result = success_result - if success: - self._value[i * self._chunksize:(i + 1) * self._chunksize] = result - self._number_left -= 1 + if success and self._success: + self._value[i*self._chunksize:(i+1)*self._chunksize] = result if self._number_left == 0: if self._callback: self._callback(self._value) - if self._accepted: - self._cache.pop(self._job, None) + del self._cache[self._job] self._event.set() + self._pool = None else: - self._success = False - self._value = result - if self._error_callback: - self._error_callback(self._value) - if self._accepted: - self._cache.pop(self._job, None) - self._event.set() - - def _ack(self, i, time_accepted, pid, *args): - start = i * self._chunksize - stop = min((i + 1) * self._chunksize, self._length) - for j in range(start, stop): - self._accepted[j] = True - self._worker_pid[j] = pid - self._time_accepted[j] = time_accepted - if self.ready(): - self._cache.pop(self._job, None) - - def accepted(self): - return all(self._accepted) - - def worker_pids(self): - return [pid for pid in self._worker_pid if pid] + if not success and 
self._success: + # only store first exception + self._success = False + self._value = result + if self._number_left == 0: + # only consider the result ready once all jobs are done + if self._error_callback: + self._error_callback(self._value) + del self._cache[self._job] + self._event.set() + self._pool = None # # Class whose instances are returned by `Pool.imap()` # - class IMapIterator(object): - _worker_lost = None - def __init__(self, cache, lost_worker_timeout=LOST_WORKER_TIMEOUT): + def __init__(self, pool): + self._pool = pool self._cond = threading.Condition(threading.Lock()) self._job = next(job_counter) - self._cache = cache - self._items = deque() + self._cache = pool._cache + self._items = collections.deque() self._index = 0 self._length = None - self._ready = False self._unsorted = {} - self._worker_pids = [] - self._lost_worker_timeout = lost_worker_timeout - cache[self._job] = self + self._cache[self._job] = self def __iter__(self): return self @@ -1947,21 +853,21 @@ def next(self, timeout=None): item = self._items.popleft() except IndexError: if self._index == self._length: - self._ready = True - raise StopIteration + self._pool = None + raise StopIteration from None self._cond.wait(timeout) try: item = self._items.popleft() except IndexError: if self._index == self._length: - self._ready = True - raise StopIteration - raise TimeoutError + self._pool = None + raise StopIteration from None + raise TimeoutError from None success, value = item if success: return value - raise Exception(value) + raise value __next__ = next # XXX @@ -1979,31 +885,21 @@ def _set(self, i, obj): self._unsorted[i] = obj if self._index == self._length: - self._ready = True del self._cache[self._job] + self._pool = None def _set_length(self, length): with self._cond: self._length = length if self._index == self._length: - self._ready = True self._cond.notify() del self._cache[self._job] - - def _ack(self, i, time_accepted, pid, *args): - self._worker_pids.append(pid) - - def 
ready(self): - return self._ready - - def worker_pids(self): - return self._worker_pids + self._pool = None # # Class whose instances are returned by `Pool.imap_unordered()` # - class IMapUnorderedIterator(IMapIterator): def _set(self, i, obj): @@ -2012,39 +908,47 @@ def _set(self, i, obj): self._index += 1 self._cond.notify() if self._index == self._length: - self._ready = True del self._cache[self._job] + self._pool = None # # # - class ThreadPool(Pool): + _wrap_exception = False - from .dummy import Process as DummyProcess - Process = DummyProcess + @staticmethod + def Process(ctx, *args, **kwds): + from .dummy import Process + return Process(*args, **kwds) def __init__(self, processes=None, initializer=None, initargs=()): Pool.__init__(self, processes, initializer, initargs) def _setup_queues(self): - self._inqueue = Queue() - self._outqueue = Queue() + self._inqueue = queue.SimpleQueue() + self._outqueue = queue.SimpleQueue() self._quick_put = self._inqueue.put self._quick_get = self._outqueue.get - def _poll_result(timeout): - try: - return True, self._quick_get(timeout=timeout) - except Empty: - return False, None - self._poll_result = _poll_result + def _get_sentinels(self): + return [self._change_notifier._reader] @staticmethod - def _help_stuff_finish(inqueue, task_handler, pool): - # put sentinels at head of inqueue to make workers finish - with inqueue.not_empty: - inqueue.queue.clear() - inqueue.queue.extend([None] * len(pool)) - inqueue.not_empty.notify_all() + def _get_worker_sentinels(workers): + return [] + + @staticmethod + def _help_stuff_finish(inqueue, task_handler, size): + # drain inqueue, and put sentinels at its head to make workers finish + try: + while True: + inqueue.get(block=False) + except queue.Empty: + pass + for i in range(size): + inqueue.put(None) + + def _wait_for_updates(self, sentinels, change_notifier, timeout): + time.sleep(timeout) diff --git a/billiard/popen_fork.py b/billiard/popen_fork.py index 8e03f4fd..625981cf 100644 
--- a/billiard/popen_fork.py +++ b/billiard/popen_fork.py @@ -1,10 +1,7 @@ -from __future__ import absolute_import - import os -import sys -import errno +import signal -from .common import TERM_SIGNAL +from . import util __all__ = ['Popen'] @@ -12,15 +9,13 @@ # Start child process using fork # - class Popen(object): method = 'fork' - sentinel = None def __init__(self, process_obj): - sys.stdout.flush() - sys.stderr.flush() + util._flush_std_streams() self.returncode = None + self.finalizer = None self._launch(process_obj) def duplicate_for_child(self, fd): @@ -28,64 +23,61 @@ def duplicate_for_child(self, fd): def poll(self, flag=os.WNOHANG): if self.returncode is None: - while True: - try: - pid, sts = os.waitpid(self.pid, flag) - except OSError as e: - if e.errno == errno.EINTR: - continue - # Child process not yet created. See #1731717 - # e.errno == errno.ECHILD == 10 - return None - else: - break + try: + pid, sts = os.waitpid(self.pid, flag) + except OSError: + # Child process not yet created. See #1731717 + # e.errno == errno.ECHILD == 10 + return None if pid == self.pid: - if os.WIFSIGNALED(sts): - self.returncode = -os.WTERMSIG(sts) - else: - assert os.WIFEXITED(sts) - self.returncode = os.WEXITSTATUS(sts) + self.returncode = os.waitstatus_to_exitcode(sts) return self.returncode def wait(self, timeout=None): if self.returncode is None: if timeout is not None: - from .connection import wait + from multiprocessing.connection import wait if not wait([self.sentinel], timeout): return None # This shouldn't block if wait() returned successfully. 
return self.poll(os.WNOHANG if timeout == 0.0 else 0) return self.returncode - def terminate(self): + def _send_signal(self, sig): if self.returncode is None: try: - os.kill(self.pid, TERM_SIGNAL) - except OSError as exc: - if getattr(exc, 'errno', None) != errno.ESRCH: - if self.wait(timeout=0.1) is None: - raise + os.kill(self.pid, sig) + except ProcessLookupError: + pass + except OSError: + if self.wait(timeout=0.1) is None: + raise + + def terminate(self): + self._send_signal(signal.SIGTERM) + + def kill(self): + self._send_signal(signal.SIGKILL) def _launch(self, process_obj): code = 1 parent_r, child_w = os.pipe() + child_r, parent_w = os.pipe() self.pid = os.fork() if self.pid == 0: try: os.close(parent_r) - if 'random' in sys.modules: - import random - random.seed() - code = process_obj._bootstrap() + os.close(parent_w) + code = process_obj._bootstrap(parent_sentinel=child_r) finally: os._exit(code) else: os.close(child_w) + os.close(child_r) + self.finalizer = util.Finalize(self, util.close_fds, + (parent_r, parent_w,)) self.sentinel = parent_r def close(self): - if self.sentinel is not None: - try: - os.close(self.sentinel) - finally: - self.sentinel = None + if self.finalizer is not None: + self.finalizer() diff --git a/billiard/popen_forkserver.py b/billiard/popen_forkserver.py index d4b18d5c..a56eb9bf 100644 --- a/billiard/popen_forkserver.py +++ b/billiard/popen_forkserver.py @@ -1,13 +1,14 @@ -from __future__ import absolute_import - import io import os -from . import reduction -from . import context +from .context import reduction, set_spawning_popen +if not reduction.HAVE_SEND_HANDLE: + raise ImportError('No support for sending fds between processes') from . import forkserver from . import popen_fork from . import spawn +from . 
import util + __all__ = ['Popen'] @@ -15,12 +16,9 @@ # Wrapper for an fd used while launching a process # - class _DupFd(object): - def __init__(self, ind): self.ind = ind - def detach(self): return forkserver.get_inherited_fds()[self.ind] @@ -28,14 +26,13 @@ def detach(self): # Start child process using a server process # - class Popen(popen_fork.Popen): method = 'forkserver' DupFd = _DupFd def __init__(self, process_obj): self._fds = [] - super(Popen, self).__init__(process_obj) + super().__init__(process_obj) def duplicate_for_child(self, fd): self._fds.append(fd) @@ -44,27 +41,34 @@ def duplicate_for_child(self, fd): def _launch(self, process_obj): prep_data = spawn.get_preparation_data(process_obj._name) buf = io.BytesIO() - context.set_spawning_popen(self) + set_spawning_popen(self) try: reduction.dump(prep_data, buf) reduction.dump(process_obj, buf) finally: - context.set_spawning_popen(None) + set_spawning_popen(None) self.sentinel, w = forkserver.connect_to_new_process(self._fds) - with io.open(w, 'wb', closefd=True) as f: + # Keep a duplicate of the data pipe's write end as a sentinel of the + # parent process used by the child process. 
+ _parent_w = os.dup(w) + self.finalizer = util.Finalize(self, util.close_fds, + (_parent_w, self.sentinel)) + with open(w, 'wb', closefd=True) as f: f.write(buf.getbuffer()) - self.pid = forkserver.read_unsigned(self.sentinel) + self.pid = forkserver.read_signed(self.sentinel) def poll(self, flag=os.WNOHANG): if self.returncode is None: - from .connection import wait + from multiprocessing.connection import wait timeout = 0 if flag == os.WNOHANG else None if not wait([self.sentinel], timeout): return None try: - self.returncode = forkserver.read_unsigned(self.sentinel) + self.returncode = forkserver.read_signed(self.sentinel) except (OSError, EOFError): - # The process ended abnormally perhaps because of a signal + # This should not happen usually, but perhaps the forkserver + # process itself got killed self.returncode = 255 + return self.returncode diff --git a/billiard/popen_spawn_posix.py b/billiard/popen_spawn_posix.py index a3f1111d..24b86345 100644 --- a/billiard/popen_spawn_posix.py +++ b/billiard/popen_spawn_posix.py @@ -1,14 +1,10 @@ -from __future__ import absolute_import - import io import os -from . import context +from .context import reduction, set_spawning_popen from . import popen_fork -from . import reduction from . import spawn - -from .compat import spawnv_passfds +from . import util __all__ = ['Popen'] @@ -18,10 +14,8 @@ # class _DupFd(object): - def __init__(self, fd): self.fd = fd - def detach(self): return self.fd @@ -29,33 +23,30 @@ def detach(self): # Start child process using a fresh interpreter # - class Popen(popen_fork.Popen): method = 'spawn' DupFd = _DupFd def __init__(self, process_obj): self._fds = [] - super(Popen, self).__init__(process_obj) + super().__init__(process_obj) def duplicate_for_child(self, fd): self._fds.append(fd) return fd def _launch(self, process_obj): - os.environ["MULTIPROCESSING_FORKING_DISABLE"] = "1" - spawn._Django_old_layout_hack__save() - from . 
import semaphore_tracker - tracker_fd = semaphore_tracker.getfd() + from . import resource_tracker + tracker_fd = resource_tracker.getfd() self._fds.append(tracker_fd) prep_data = spawn.get_preparation_data(process_obj._name) fp = io.BytesIO() - context.set_spawning_popen(self) + set_spawning_popen(self) try: reduction.dump(prep_data, fp) reduction.dump(process_obj, fp) finally: - context.set_spawning_popen(None) + set_spawning_popen(None) parent_r = child_w = child_r = parent_w = None try: @@ -64,13 +55,18 @@ def _launch(self, process_obj): cmd = spawn.get_command_line(tracker_fd=tracker_fd, pipe_handle=child_r) self._fds.extend([child_r, child_w]) - self.pid = spawnv_passfds( - spawn.get_executable(), cmd, self._fds, - ) + self.pid = util.spawnv_passfds(spawn.get_executable(), + cmd, self._fds) self.sentinel = parent_r - with io.open(parent_w, 'wb', closefd=False) as f: - f.write(fp.getvalue()) + with open(parent_w, 'wb', closefd=False) as f: + f.write(fp.getbuffer()) finally: - for fd in (child_r, child_w, parent_w): + fds_to_close = [] + for fd in (parent_r, parent_w): + if fd is not None: + fds_to_close.append(fd) + self.finalizer = util.Finalize(self, util.close_fds, fds_to_close) + + for fd in (child_r, child_w): if fd is not None: os.close(fd) diff --git a/billiard/popen_spawn_win32.py b/billiard/popen_spawn_win32.py index 20fcc771..9c4098d0 100644 --- a/billiard/popen_spawn_win32.py +++ b/billiard/popen_spawn_win32.py @@ -1,16 +1,12 @@ -from __future__ import absolute_import - -import io import os import msvcrt import signal import sys +import _winapi -from . import context +from .context import reduction, get_spawning_popen, set_spawning_popen from . import spawn -from . import reduction - -from .compat import _winapi +from . 
import util __all__ = ['Popen'] @@ -22,50 +18,62 @@ WINEXE = (sys.platform == 'win32' and getattr(sys, 'frozen', False)) WINSERVICE = sys.executable.lower().endswith("pythonservice.exe") -# -# We define a Popen class similar to the one from subprocess, but -# whose constructor takes a process object as its argument. -# +def _path_eq(p1, p2): + return p1 == p2 or os.path.normcase(p1) == os.path.normcase(p2) -if sys.platform == 'win32': - try: - from _winapi import CreateProcess, GetExitCodeProcess - close_thread_handle = _winapi.CloseHandle - except ImportError: # Py2.7 - from _subprocess import CreateProcess, GetExitCodeProcess +WINENV = not _path_eq(sys.executable, sys._base_executable) - def close_thread_handle(handle): - handle.Close() +def _close_handles(*handles): + for handle in handles: + _winapi.CloseHandle(handle) + + +# +# We define a Popen class similar to the one from subprocess, but +# whose constructor takes a process object as its argument. +# class Popen(object): ''' Start a subprocess to run the code of a process object ''' method = 'spawn' - sentinel = None def __init__(self, process_obj): - os.environ["MULTIPROCESSING_FORKING_DISABLE"] = "1" - spawn._Django_old_layout_hack__save() prep_data = spawn.get_preparation_data(process_obj._name) - # read end of pipe will be "stolen" by the child process + # read end of pipe will be duplicated by the child process # -- see spawn_main() in spawn.py. + # + # bpo-33929: Previously, the read end of pipe was "stolen" by the child + # process, but it leaked a handle if the child process had been + # terminated before it could steal the handle from the parent process. 
rhandle, whandle = _winapi.CreatePipe(None, 0) wfd = msvcrt.open_osfhandle(whandle, 0) cmd = spawn.get_command_line(parent_pid=os.getpid(), pipe_handle=rhandle) cmd = ' '.join('"%s"' % x for x in cmd) - with io.open(wfd, 'wb', closefd=True) as to_child: + python_exe = spawn.get_executable() + + # bpo-35797: When running in a venv, we bypass the redirect + # executor and launch our base Python. + if WINENV and _path_eq(python_exe, sys.executable): + python_exe = sys._base_executable + env = os.environ.copy() + env["__PYVENV_LAUNCHER__"] = sys.executable + else: + env = None + + with open(wfd, 'wb', closefd=True) as to_child: # start process try: - hp, ht, pid, tid = CreateProcess( - spawn.get_executable(), cmd, - None, None, False, 0, None, None, None) - close_thread_handle(ht) + hp, ht, pid, tid = _winapi.CreateProcess( + python_exe, cmd, + None, None, False, 0, env, None, None) + _winapi.CloseHandle(ht) except: _winapi.CloseHandle(rhandle) raise @@ -75,24 +83,19 @@ def __init__(self, process_obj): self.returncode = None self._handle = hp self.sentinel = int(hp) + self.finalizer = util.Finalize(self, _close_handles, + (self.sentinel, int(rhandle))) # send information to child - context.set_spawning_popen(self) + set_spawning_popen(self) try: reduction.dump(prep_data, to_child) reduction.dump(process_obj, to_child) finally: - context.set_spawning_popen(None) - - def close(self): - if self.sentinel is not None: - try: - _winapi.CloseHandle(self.sentinel) - finally: - self.sentinel = None + set_spawning_popen(None) def duplicate_for_child(self, handle): - assert self is context.get_spawning_popen() + assert self is get_spawning_popen() return reduction.duplicate(handle, self.sentinel) def wait(self, timeout=None): @@ -104,7 +107,7 @@ def wait(self, timeout=None): res = _winapi.WaitForSingleObject(int(self._handle), msecs) if res == _winapi.WAIT_OBJECT_0: - code = GetExitCodeProcess(self._handle) + code = _winapi.GetExitCodeProcess(self._handle) if code == TERMINATE: 
code = -signal.SIGTERM self.returncode = code @@ -121,3 +124,8 @@ def terminate(self): except OSError: if self.wait(timeout=1.0) is None: raise + + kill = terminate + + def close(self): + self.finalizer() diff --git a/billiard/process.py b/billiard/process.py index 59e3ba4c..0b2e0b45 100644 --- a/billiard/process.py +++ b/billiard/process.py @@ -6,7 +6,9 @@ # Copyright (c) 2006-2008, R Oudkerk # Licensed to PSF under a Contributor Agreement. # -from __future__ import absolute_import + +__all__ = ['BaseProcess', 'current_process', 'active_children', + 'parent_process'] # # Imports @@ -16,37 +18,45 @@ import sys import signal import itertools -import logging import threading from _weakrefset import WeakSet -from multiprocessing import process as _mproc - -from .five import items, string_t +# +# +# try: ORIGINAL_DIR = os.path.abspath(os.getcwd()) except OSError: ORIGINAL_DIR = None -__all__ = ['BaseProcess', 'Process', 'current_process', 'active_children'] - # # Public functions # - def current_process(): ''' Return process object representing the current process ''' return _current_process +def active_children(): + ''' + Return list of process objects corresponding to live child processes + ''' + _cleanup() + return list(_children) + -def _set_current_process(process): - global _current_process - _current_process = _mproc._current_process = process +def parent_process(): + ''' + Return process object representing the parent process + ''' + return _parent_process +# +# +# def _cleanup(): # check for processes which have finished @@ -54,57 +64,41 @@ def _cleanup(): if p._popen.poll() is not None: _children.discard(p) - -def _maybe_flush(f): - try: - f.flush() - except (AttributeError, EnvironmentError, NotImplementedError): - pass - - -def active_children(_cleanup=_cleanup): - ''' - Return list of process objects corresponding to live child processes - ''' - try: - _cleanup() - except TypeError: - # called after gc collect so _cleanup does not exist anymore - return [] 
- return list(_children) - +# +# The `Process` class +# class BaseProcess(object): ''' Process objects represent activity that is run in a separate process - The class is analagous to `threading.Thread` + The class is analogous to `threading.Thread` ''' - def _Popen(self): - raise NotImplementedError() + raise NotImplementedError - def __init__(self, group=None, target=None, name=None, - args=(), kwargs={}, daemon=None, **_kw): + def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, + *, daemon=None): assert group is None, 'group argument must be None for now' count = next(_process_counter) - self._identity = _current_process._identity + (count, ) + self._identity = _current_process._identity + (count,) self._config = _current_process._config.copy() self._parent_pid = os.getpid() + self._parent_name = _current_process.name self._popen = None + self._closed = False self._target = target self._args = tuple(args) self._kwargs = dict(kwargs) - self._name = ( - name or type(self).__name__ + '-' + - ':'.join(str(i) for i in self._identity) - ) + self._name = name or type(self).__name__ + '-' + \ + ':'.join(str(i) for i in self._identity) if daemon is not None: self.daemon = daemon - if _dangling is not None: - _dangling.add(self) - - self._controlled_termination = False + _dangling.add(self) + + def _check_closed(self): + if self._closed: + raise ValueError("process object is closed") def run(self): ''' @@ -117,63 +111,88 @@ def start(self): ''' Start child process ''' + self._check_closed() assert self._popen is None, 'cannot start a process twice' assert self._parent_pid == os.getpid(), \ - 'can only start a process object created by current process' + 'can only start a process object created by current process' + assert not _current_process._config.get('daemon'), \ + 'daemonic processes are not allowed to have children' _cleanup() self._popen = self._Popen(self) self._sentinel = self._popen.sentinel + # Avoid a refcycle if the target function 
holds an indirect + # reference to the process object (see bpo-30775) + del self._target, self._args, self._kwargs _children.add(self) - def close(self): - if self._popen is not None: - self._popen.close() - def terminate(self): ''' Terminate process; sends SIGTERM signal or uses TerminateProcess() ''' + self._check_closed() self._popen.terminate() - - def terminate_controlled(self): - self._controlled_termination = True - self.terminate() + + def kill(self): + ''' + Terminate process; sends SIGKILL signal or uses TerminateProcess() + ''' + self._check_closed() + self._popen.kill() def join(self, timeout=None): ''' Wait until child process terminates ''' + self._check_closed() assert self._parent_pid == os.getpid(), 'can only join a child process' assert self._popen is not None, 'can only join a started process' res = self._popen.wait(timeout) if res is not None: _children.discard(self) - self.close() def is_alive(self): ''' Return whether process is alive ''' + self._check_closed() if self is _current_process: return True assert self._parent_pid == os.getpid(), 'can only test a child process' + if self._popen is None: return False - self._popen.poll() - return self._popen.returncode is None - def _is_alive(self): - if self._popen is None: + returncode = self._popen.poll() + if returncode is None: + return True + else: + _children.discard(self) return False - return self._popen.poll() is None + + def close(self): + ''' + Close the Process object. + + This method releases resources held by the Process object. It is + an error to call this method if the child process is still running. + ''' + if self._popen is not None: + if self._popen.poll() is None: + raise ValueError("Cannot close a process while it is still running. 
" + "You should first call join() or terminate().") + self._popen.close() + self._popen = None + del self._sentinel + _children.discard(self) + self._closed = True @property def name(self): return self._name @name.setter - def name(self, name): # noqa - assert isinstance(name, string_t), 'name must be a string' + def name(self, name): + assert isinstance(name, str), 'name must be a string' self._name = name @property @@ -183,7 +202,7 @@ def daemon(self): ''' return self._config.get('daemon', False) - @daemon.setter # noqa + @daemon.setter def daemon(self, daemonic): ''' Set whether process is a daemon @@ -195,7 +214,7 @@ def daemon(self, daemonic): def authkey(self): return self._config['authkey'] - @authkey.setter # noqa + @authkey.setter def authkey(self, authkey): ''' Set authorization key of process @@ -207,6 +226,7 @@ def exitcode(self): ''' Return exit code of process or `None` if it has yet to stop ''' + self._check_closed() if self._popen is None: return self._popen return self._popen.poll() @@ -216,6 +236,7 @@ def ident(self): ''' Return identifier (PID) of process or `None` if it has yet to start ''' + self._check_closed() if self is _current_process: return os.getpid() else: @@ -229,92 +250,59 @@ def sentinel(self): Return a file descriptor (Unix) or handle (Windows) suitable for waiting for process termination. 
''' + self._check_closed() try: return self._sentinel except AttributeError: - raise ValueError("process not started") - - @property - def _counter(self): - # compat for 2.7 - return _process_counter - - @property - def _children(self): - # compat for 2.7 - return _children - - @property - def _authkey(self): - # compat for 2.7 - return self.authkey - - @property - def _daemonic(self): - # compat for 2.7 - return self.daemon - - @property - def _tempdir(self): - # compat for 2.7 - return self._config.get('tempdir') + raise ValueError("process not started") from None def __repr__(self): + exitcode = None if self is _current_process: status = 'started' + elif self._closed: + status = 'closed' elif self._parent_pid != os.getpid(): status = 'unknown' elif self._popen is None: status = 'initial' else: - if self._popen.poll() is not None: - status = self.exitcode - else: - status = 'started' - - if type(status) is int: - if status == 0: + exitcode = self._popen.poll() + if exitcode is not None: status = 'stopped' else: - status = 'stopped[%s]' % _exitcode_to_name.get(status, status) + status = 'started' - return '<%s(%s, %s%s)>' % (type(self).__name__, self._name, - status, self.daemon and ' daemon' or '') + info = [type(self).__name__, 'name=%r' % self._name] + if self._popen is not None: + info.append('pid=%s' % self._popen.pid) + info.append('parent=%s' % self._parent_pid) + info.append(status) + if exitcode is not None: + exitcode = _exitcode_to_name.get(exitcode, exitcode) + info.append('exitcode=%s' % exitcode) + if self.daemon: + info.append('daemon') + return '<%s>' % ' '.join(info) ## - def _bootstrap(self): + def _bootstrap(self, parent_sentinel=None): from . 
import util, context - global _current_process, _process_counter, _children + global _current_process, _parent_process, _process_counter, _children try: if self._start_method is not None: context._force_start_method(self._start_method) _process_counter = itertools.count(1) _children = set() - if sys.stdin is not None: - try: - sys.stdin.close() - sys.stdin = open(os.devnull) - except (OSError, ValueError): - pass + util._close_stdin() old_process = _current_process - _set_current_process(self) - - # Re-init logging system. - # Workaround for https://bugs.python.org/issue6721/#msg140215 - # Python logging module uses RLock() objects which are broken - # after fork. This can result in a deadlock (Celery Issue #496). - loggerDict = logging.Logger.manager.loggerDict - logger_names = list(loggerDict.keys()) - logger_names.append(None) # for root logger - for name in logger_names: - if not name or not isinstance(loggerDict[name], - logging.PlaceHolder): - for handler in logging.getLogger(name).handlers: - handler.createLock() - logging._lock = threading.RLock() - + _current_process = self + _parent_process = _ParentProcess( + self._parent_name, self._parent_pid, parent_sentinel) + if threading._HAVE_THREAD_NATIVE_ID: + threading.main_thread()._set_native_id() try: util._finalizer_registry.clear() util._run_after_forkers() @@ -322,32 +310,29 @@ def _bootstrap(self): # delay finalization of the old process object until after # _run_after_forkers() is executed del old_process - util.info('child process %s calling self.run()', self.pid) + util.info('child process calling self.run()') try: self.run() exitcode = 0 finally: util._exit_function() - except SystemExit as exc: - if not exc.args: - exitcode = 1 - elif isinstance(exc.args[0], int): - exitcode = exc.args[0] + except SystemExit as e: + if e.code is None: + exitcode = 0 + elif isinstance(e.code, int): + exitcode = e.code else: - sys.stderr.write(str(exc.args[0]) + '\n') - _maybe_flush(sys.stderr) - exitcode = 0 if 
isinstance(exc.args[0], str) else 1 + sys.stderr.write(str(e.code) + '\n') + exitcode = 1 except: exitcode = 1 - if not util.error('Process %s', self.name, exc_info=True): - import traceback - sys.stderr.write('Process %s:\n' % self.name) - traceback.print_exc() + import traceback + sys.stderr.write('Process %s:\n' % self.name) + traceback.print_exc() finally: - util.info('process %s exiting with exitcode %d', - self.pid, exitcode) - _maybe_flush(sys.stdout) - _maybe_flush(sys.stderr) + threading._shutdown() + util.info('process exiting with exitcode %d' % exitcode) + util._flush_std_streams() return exitcode @@ -355,22 +340,53 @@ def _bootstrap(self): # We subclass bytes to avoid accidental transmission of auth keys over network # - class AuthenticationString(bytes): - def __reduce__(self): from .context import get_spawning_popen - if get_spawning_popen() is None: raise TypeError( 'Pickling an AuthenticationString object is ' - 'disallowed for security reasons') + 'disallowed for security reasons' + ) return AuthenticationString, (bytes(self),) + # -# Create object representing the main process +# Create object representing the parent process # +class _ParentProcess(BaseProcess): + + def __init__(self, name, pid, sentinel): + self._identity = () + self._name = name + self._pid = pid + self._parent_pid = None + self._popen = None + self._closed = False + self._sentinel = sentinel + self._config = {} + + def is_alive(self): + from multiprocessing.connection import wait + return not wait([self._sentinel], timeout=0) + + @property + def ident(self): + return self._pid + + def join(self, timeout=None): + ''' + Wait until parent process terminates + ''' + from multiprocessing.connection import wait + wait([self._sentinel], timeout=timeout) + + pid = ident + +# +# Create object representing the main process +# class _MainProcess(BaseProcess): @@ -379,26 +395,38 @@ def __init__(self): self._name = 'MainProcess' self._parent_pid = None self._popen = None + self._closed = 
False self._config = {'authkey': AuthenticationString(os.urandom(32)), 'semprefix': '/mp'} + # Note that some versions of FreeBSD only allow named + # semaphores to have names of up to 14 characters. Therefore + # we choose a short prefix. + # + # On MacOSX in a sandbox it may be necessary to use a + # different prefix -- see #19478. + # + # Everything in self._config will be inherited by descendant + # processes. + def close(self): + pass + + +_parent_process = None _current_process = _MainProcess() _process_counter = itertools.count(1) _children = set() del _MainProcess - -Process = BaseProcess - # # Give names to some return codes # _exitcode_to_name = {} -for name, signum in items(signal.__dict__): - if name[:3] == 'SIG' and '_' not in name: - _exitcode_to_name[-signum] = name +for name, signum in list(signal.__dict__.items()): + if name[:3]=='SIG' and '_' not in name: + _exitcode_to_name[-signum] = f'-{name}' # For debug and leak testing _dangling = WeakSet() diff --git a/billiard/queues.py b/billiard/queues.py index 8e0511b7..a2901814 100644 --- a/billiard/queues.py +++ b/billiard/queues.py @@ -6,40 +6,38 @@ # Copyright (c) 2006-2008, R Oudkerk # Licensed to PSF under a Contributor Agreement. # -from __future__ import absolute_import + +__all__ = ['Queue', 'SimpleQueue', 'JoinableQueue'] import sys import os import threading import collections +import time +import types import weakref import errno +from queue import Empty, Full + +import _multiprocessing + from . import connection from . 
import context +_ForkingPickler = context.reduction.ForkingPickler -from .compat import get_errno -from .five import monotonic, Empty, Full -from .util import ( - debug, error, info, Finalize, register_after_fork, is_exiting, -) -from .reduction import ForkingPickler - -__all__ = ['Queue', 'SimpleQueue', 'JoinableQueue'] +from .util import debug, info, Finalize, register_after_fork, is_exiting +# +# Queue type using a pipe, buffer and thread +# class Queue(object): - ''' - Queue type using a pipe, buffer and thread - ''' - def __init__(self, maxsize=0, *args, **kwargs): - try: - ctx = kwargs['ctx'] - except KeyError: - raise TypeError('missing 1 required keyword-only argument: ctx') + + def __init__(self, maxsize=0, *, ctx): if maxsize <= 0: # Can raise ImportError (see issues #3770 and #23400) - from .synchronize import SEM_VALUE_MAX as maxsize # noqa + from .synchronize import SEM_VALUE_MAX as maxsize self._maxsize = maxsize self._reader, self._writer = connection.Pipe(duplex=False) self._rlock = ctx.Lock() @@ -51,8 +49,7 @@ def __init__(self, maxsize=0, *args, **kwargs): self._sem = ctx.BoundedSemaphore(maxsize) # For use by concurrent.futures self._ignore_epipe = False - - self._after_fork() + self._reset() if sys.platform != 'win32': register_after_fork(self, Queue._after_fork) @@ -65,25 +62,30 @@ def __getstate__(self): def __setstate__(self, state): (self._ignore_epipe, self._maxsize, self._reader, self._writer, self._rlock, self._wlock, self._sem, self._opid) = state - self._after_fork() + self._reset() def _after_fork(self): debug('Queue._after_fork()') - self._notempty = threading.Condition(threading.Lock()) + self._reset(after_fork=True) + + def _reset(self, after_fork=False): + if after_fork: + self._notempty._at_fork_reinit() + else: + self._notempty = threading.Condition(threading.Lock()) self._buffer = collections.deque() self._thread = None self._jointhread = None self._joincancelled = False self._closed = False self._close = None - self._send_bytes 
= self._writer.send - self._recv = self._reader.recv self._send_bytes = self._writer.send_bytes self._recv_bytes = self._reader.recv_bytes self._poll = self._reader.poll def put(self, obj, block=True, timeout=None): - assert not self._closed + if self._closed: + raise ValueError(f"Queue {self!r} is closed") if not self._sem.acquire(block, timeout): raise Full @@ -94,20 +96,21 @@ def put(self, obj, block=True, timeout=None): self._notempty.notify() def get(self, block=True, timeout=None): + if self._closed: + raise ValueError(f"Queue {self!r} is closed") if block and timeout is None: with self._rlock: res = self._recv_bytes() self._sem.release() - else: if block: - deadline = monotonic() + timeout + deadline = time.monotonic() + timeout if not self._rlock.acquire(block, timeout): raise Empty try: if block: - timeout = deadline - monotonic() - if timeout < 0 or not self._poll(timeout): + timeout = deadline - time.monotonic() + if not self._poll(timeout): raise Empty elif not self._poll(): raise Empty @@ -116,11 +119,10 @@ def get(self, block=True, timeout=None): finally: self._rlock.release() # unserialize the data after having released the lock - return ForkingPickler.loads(res) + return _ForkingPickler.loads(res) def qsize(self): - # Raises NotImplementedError on macOS because - # of broken sem_getvalue() + # Raises NotImplementedError on Mac OSX because of broken sem_getvalue() return self._maxsize - self._sem._semlock._get_value() def empty(self): @@ -147,7 +149,7 @@ def close(self): def join_thread(self): debug('Queue.join_thread()') - assert self._closed + assert self._closed, "Queue {0!r} not closed".format(self) if self._jointhread: self._jointhread() @@ -167,7 +169,8 @@ def _start_thread(self): self._thread = threading.Thread( target=Queue._feed, args=(self._buffer, self._notempty, self._send_bytes, - self._wlock, self._writer.close, self._ignore_epipe), + self._wlock, self._writer.close, self._ignore_epipe, + self._on_queue_feeder_error, self._sem), 
name='QueueFeederThread' ) self._thread.daemon = True @@ -176,26 +179,19 @@ def _start_thread(self): self._thread.start() debug('... done self._thread.start()') - # On process exit we will wait for data to be flushed to pipe. - # - # However, if this process created the queue then all - # processes which use the queue will be descendants of this - # process. Therefore waiting for the queue to be flushed - # is pointless once all the child processes have been joined. - created_by_this_process = (self._opid == os.getpid()) - if not self._joincancelled and not created_by_this_process: + if not self._joincancelled: self._jointhread = Finalize( self._thread, Queue._finalize_join, [weakref.ref(self._thread)], exitpriority=-5 - ) + ) # Send sentinel to the thread queue object when garbage collected self._close = Finalize( self, Queue._finalize_close, [self._buffer, self._notempty], exitpriority=10 - ) + ) @staticmethod def _finalize_join(twr): @@ -215,9 +211,9 @@ def _finalize_close(buffer, notempty): notempty.notify() @staticmethod - def _feed(buffer, notempty, send_bytes, writelock, close, ignore_epipe): + def _feed(buffer, notempty, send_bytes, writelock, close, ignore_epipe, + onerror, queue_sem): debug('starting thread to feed data to pipe') - nacquire = notempty.acquire nrelease = notempty.release nwait = notempty.wait @@ -229,8 +225,8 @@ def _feed(buffer, notempty, send_bytes, writelock, close, ignore_epipe): else: wacquire = None - try: - while 1: + while 1: + try: nacquire() try: if not buffer: @@ -246,7 +242,7 @@ def _feed(buffer, notempty, send_bytes, writelock, close, ignore_epipe): return # serialize the data before acquiring the lock - obj = ForkingPickler.dumps(obj) + obj = _ForkingPickler.dumps(obj) if wacquire is None: send_bytes(obj) else: @@ -257,41 +253,48 @@ def _feed(buffer, notempty, send_bytes, writelock, close, ignore_epipe): wrelease() except IndexError: pass - except Exception as exc: - if ignore_epipe and get_errno(exc) == errno.EPIPE: - return 
- # Since this runs in a daemon thread the resources it uses - # may be become unusable while the process is cleaning up. - # We ignore errors which happen after the process has - # started to cleanup. - try: + except Exception as e: + if ignore_epipe and getattr(e, 'errno', 0) == errno.EPIPE: + return + # Since this runs in a daemon thread the resources it uses + # may be become unusable while the process is cleaning up. + # We ignore errors which happen after the process has + # started to cleanup. if is_exiting(): - info('error in queue thread: %r', exc, exc_info=True) + info('error in queue thread: %s', e) + return else: - if not error('error in queue thread: %r', exc, - exc_info=True): - import traceback - traceback.print_exc() - except Exception: - pass + # Since the object has not been sent in the queue, we need + # to decrease the size of the queue. The error acts as + # if the object had been silently removed from the queue + # and this step is necessary to have a properly working + # queue. + queue_sem.release() + onerror(e, obj) + + @staticmethod + def _on_queue_feeder_error(e, obj): + """ + Private API hook called when feeding data in the background thread + raises an exception. For overriding by concurrent.futures. + """ + import traceback + traceback.print_exc() + _sentinel = object() +# +# A queue type which also supports join() and task_done() methods +# +# Note that if you do not call task_done() for each finished task then +# eventually the counter's semaphore may overflow causing Bad Things +# to happen. +# class JoinableQueue(Queue): - ''' - A queue type which also supports join() and task_done() methods - Note that if you do not call task_done() for each finished task then - eventually the counter's semaphore may overflow causing Bad Things - to happen. 
- ''' - - def __init__(self, maxsize=0, *args, **kwargs): - try: - ctx = kwargs['ctx'] - except KeyError: - raise TypeError('missing 1 required keyword argument: ctx') + def __init__(self, maxsize=0, *, ctx): Queue.__init__(self, maxsize, ctx=ctx) self._unfinished_tasks = ctx.Semaphore(0) self._cond = ctx.Condition() @@ -304,17 +307,17 @@ def __setstate__(self, state): self._cond, self._unfinished_tasks = state[-2:] def put(self, obj, block=True, timeout=None): - assert not self._closed + if self._closed: + raise ValueError(f"Queue {self!r} is closed") if not self._sem.acquire(block, timeout): raise Full - with self._notempty: - with self._cond: - if self._thread is None: - self._start_thread() - self._buffer.append(obj) - self._unfinished_tasks.release() - self._notempty.notify() + with self._notempty, self._cond: + if self._thread is None: + self._start_thread() + self._buffer.append(obj) + self._unfinished_tasks.release() + self._notempty.notify() def task_done(self): with self._cond: @@ -328,18 +331,24 @@ def join(self): if not self._unfinished_tasks._semlock._is_zero(): self._cond.wait() +# +# Simplified Queue type -- really just a locked pipe +# -class _SimpleQueue(object): - ''' - Simplified Queue type -- really just a locked pipe - ''' +class SimpleQueue(object): - def __init__(self, rnonblock=False, wnonblock=False, ctx=None): - self._reader, self._writer = connection.Pipe( - duplex=False, rnonblock=rnonblock, wnonblock=wnonblock, - ) + def __init__(self, *, ctx): + self._reader, self._writer = connection.Pipe(duplex=False) + self._rlock = ctx.Lock() self._poll = self._reader.poll - self._rlock = self._wlock = None + if sys.platform == 'win32': + self._wlock = None + else: + self._wlock = ctx.Lock() + + def close(self): + self._reader.close() + self._writer.close() def empty(self): return not self._poll() @@ -350,41 +359,22 @@ def __getstate__(self): def __setstate__(self, state): (self._reader, self._writer, self._rlock, self._wlock) = state - - def 
get_payload(self): - return self._reader.recv_bytes() - - def send_payload(self, value): - self._writer.send_bytes(value) + self._poll = self._reader.poll def get(self): + with self._rlock: + res = self._reader.recv_bytes() # unserialize the data after having released the lock - return ForkingPickler.loads(self.get_payload()) + return _ForkingPickler.loads(res) def put(self, obj): # serialize the data before acquiring the lock - self.send_payload(ForkingPickler.dumps(obj)) - - -class SimpleQueue(_SimpleQueue): - - def __init__(self, *args, **kwargs): - try: - ctx = kwargs['ctx'] - except KeyError: - raise TypeError('missing required keyword argument: ctx') - self._reader, self._writer = connection.Pipe(duplex=False) - self._rlock = ctx.Lock() - self._wlock = ctx.Lock() if sys.platform != 'win32' else None - - def get_payload(self): - with self._rlock: - return self._reader.recv_bytes() - - def send_payload(self, value): + obj = _ForkingPickler.dumps(obj) if self._wlock is None: # writes to a message oriented win32 pipe are atomic - self._writer.send_bytes(value) + self._writer.send_bytes(obj) else: with self._wlock: - self._writer.send_bytes(value) + self._writer.send_bytes(obj) + + __class_getitem__ = classmethod(types.GenericAlias) diff --git a/billiard/reduction.py b/billiard/reduction.py index a2fae25d..5593f068 100644 --- a/billiard/reduction.py +++ b/billiard/reduction.py @@ -6,8 +6,9 @@ # Copyright (c) 2006-2008, R Oudkerk # Licensed to PSF under a Contributor Agreement. 
# -from __future__ import absolute_import +from abc import ABCMeta +import copyreg import functools import io import os @@ -19,8 +20,6 @@ __all__ = ['send_handle', 'recv_handle', 'ForkingPickler', 'register', 'dump'] -PY3 = sys.version_info[0] == 3 - HAVE_SEND_HANDLE = (sys.platform == 'win32' or (hasattr(socket, 'CMSG_LEN') and @@ -31,69 +30,31 @@ # Pickler subclass # +class ForkingPickler(pickle.Pickler): + '''Pickler subclass used by multiprocessing.''' + _extra_reducers = {} + _copyreg_dispatch_table = copyreg.dispatch_table -if PY3: - import copyreg - - class ForkingPickler(pickle.Pickler): - '''Pickler subclass used by multiprocessing.''' - _extra_reducers = {} - _copyreg_dispatch_table = copyreg.dispatch_table - - def __init__(self, *args): - super(ForkingPickler, self).__init__(*args) - self.dispatch_table = self._copyreg_dispatch_table.copy() - self.dispatch_table.update(self._extra_reducers) + def __init__(self, *args): + super().__init__(*args) + self.dispatch_table = self._copyreg_dispatch_table.copy() + self.dispatch_table.update(self._extra_reducers) - @classmethod - def register(cls, type, reduce): - '''Register a reduce function for a type.''' - cls._extra_reducers[type] = reduce + @classmethod + def register(cls, type, reduce): + '''Register a reduce function for a type.''' + cls._extra_reducers[type] = reduce - @classmethod - def dumps(cls, obj, protocol=None): - buf = io.BytesIO() - cls(buf, protocol).dump(obj) - return buf.getbuffer() + @classmethod + def dumps(cls, obj, protocol=None): + buf = io.BytesIO() + cls(buf, protocol).dump(obj) + return buf.getbuffer() - @classmethod - def loadbuf(cls, buf, protocol=None): - return cls.loads(buf.getbuffer()) - - loads = pickle.loads - -else: + loads = pickle.loads - class ForkingPickler(pickle.Pickler): # noqa - '''Pickler subclass used by multiprocessing.''' - dispatch = pickle.Pickler.dispatch.copy() - - @classmethod - def register(cls, type, reduce): - '''Register a reduce function for a type.''' - 
def dispatcher(self, obj): - rv = reduce(obj) - self.save_reduce(obj=obj, *rv) - cls.dispatch[type] = dispatcher - - @classmethod - def dumps(cls, obj, protocol=None): - buf = io.BytesIO() - cls(buf, protocol).dump(obj) - return buf.getvalue() - - @classmethod - def loadbuf(cls, buf, protocol=None): - return cls.loads(buf.getvalue()) - - @classmethod - def loads(cls, buf, loads=pickle.loads): - if isinstance(buf, io.BytesIO): - buf = buf.getvalue() - return loads(buf) register = ForkingPickler.register - def dump(obj, file, protocol=None): '''Replacement for pickle.dump() using ForkingPickler.''' ForkingPickler(file, protocol).dump(obj) @@ -105,14 +66,18 @@ def dump(obj, file, protocol=None): if sys.platform == 'win32': # Windows __all__ += ['DupHandle', 'duplicate', 'steal_handle'] - from .compat import _winapi + import _winapi - def duplicate(handle, target_process=None, inheritable=False): + def duplicate(handle, target_process=None, inheritable=False, + *, source_process=None): '''Duplicate a handle. 
(target_process is a handle not a pid!)''' + current_process = _winapi.GetCurrentProcess() + if source_process is None: + source_process = current_process if target_process is None: - target_process = _winapi.GetCurrentProcess() + target_process = current_process return _winapi.DuplicateHandle( - _winapi.GetCurrentProcess(), handle, target_process, + source_process, handle, target_process, 0, inheritable, _winapi.DUPLICATE_SAME_ACCESS) def steal_handle(source_pid, handle): @@ -174,7 +139,7 @@ def detach(self): __all__ += ['DupFd', 'sendfds', 'recvfds'] import array - # On macOS we should acknowledge receipt of fds -- see Issue14669 + # On MacOSX we should acknowledge receipt of fds -- see Issue14669 ACKNOWLEDGE = sys.platform == 'darwin' def sendfds(sock, fds): @@ -189,40 +154,38 @@ def recvfds(sock, size): '''Receive an array of fds over an AF_UNIX socket.''' a = array.array('i') bytes_size = a.itemsize * size - msg, ancdata, flags, addr = sock.recvmsg( - 1, socket.CMSG_LEN(bytes_size), - ) + msg, ancdata, flags, addr = sock.recvmsg(1, socket.CMSG_SPACE(bytes_size)) if not msg and not ancdata: raise EOFError try: if ACKNOWLEDGE: sock.send(b'A') if len(ancdata) != 1: - raise RuntimeError( - 'received %d items of ancdata' % len(ancdata), - ) + raise RuntimeError('received %d items of ancdata' % + len(ancdata)) cmsg_level, cmsg_type, cmsg_data = ancdata[0] if (cmsg_level == socket.SOL_SOCKET and - cmsg_type == socket.SCM_RIGHTS): + cmsg_type == socket.SCM_RIGHTS): if len(cmsg_data) % a.itemsize != 0: raise ValueError a.frombytes(cmsg_data) - assert len(a) % 256 == msg[0] + if len(a) % 256 != msg[0]: + raise AssertionError( + "Len is {0:n} but msg[0] is {1!r}".format( + len(a), msg[0])) return list(a) except (ValueError, IndexError): pass raise RuntimeError('Invalid data received') - def send_handle(conn, handle, destination_pid): # noqa + def send_handle(conn, handle, destination_pid): '''Send a handle over a local connection.''' - fd = conn.fileno() - with 
socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) as s: + with socket.fromfd(conn.fileno(), socket.AF_UNIX, socket.SOCK_STREAM) as s: sendfds(s, [handle]) - def recv_handle(conn): # noqa + def recv_handle(conn): '''Receive a handle over a local connection.''' - fd = conn.fileno() - with socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) as s: + with socket.fromfd(conn.fileno(), socket.AF_UNIX, socket.SOCK_STREAM) as s: return recvfds(s, 1)[0] def DupFd(fd): @@ -240,14 +203,11 @@ def DupFd(fd): # Try making some callable types picklable # - def _reduce_method(m): if m.__self__ is None: return getattr, (m.__class__, m.__func__.__name__) else: return getattr, (m.__self__, m.__func__.__name__) - - class _C: def f(self): pass @@ -262,8 +222,6 @@ def _reduce_method_descriptor(m): def _reduce_partial(p): return _rebuild_partial, (p.func, p.args, p.keywords or {}) - - def _rebuild_partial(func, args, keywords): return functools.partial(func, *args, **keywords) register(functools.partial, _reduce_partial) @@ -273,22 +231,51 @@ def _rebuild_partial(func, args, keywords): # if sys.platform == 'win32': - def _reduce_socket(s): from .resource_sharer import DupSocket return _rebuild_socket, (DupSocket(s),) - def _rebuild_socket(ds): return ds.detach() register(socket.socket, _reduce_socket) else: - - def _reduce_socket(s): # noqa + def _reduce_socket(s): df = DupFd(s.fileno()) return _rebuild_socket, (df, s.family, s.type, s.proto) - - def _rebuild_socket(df, family, type, proto): # noqa + def _rebuild_socket(df, family, type, proto): fd = df.detach() return socket.socket(family, type, proto, fileno=fd) register(socket.socket, _reduce_socket) + + +class AbstractReducer(metaclass=ABCMeta): + '''Abstract base class for use in implementing a Reduction class + suitable for use in replacing the standard reduction mechanism + used in multiprocessing.''' + ForkingPickler = ForkingPickler + register = register + dump = dump + send_handle = send_handle + recv_handle = recv_handle 
+ + if sys.platform == 'win32': + steal_handle = steal_handle + duplicate = duplicate + DupHandle = DupHandle + else: + sendfds = sendfds + recvfds = recvfds + DupFd = DupFd + + _reduce_method = _reduce_method + _reduce_method_descriptor = _reduce_method_descriptor + _rebuild_partial = _rebuild_partial + _reduce_socket = _reduce_socket + _rebuild_socket = _rebuild_socket + + def __init__(self, *args): + register(type(_C().f), _reduce_method) + register(type(list.append), _reduce_method_descriptor) + register(type(int.__add__), _reduce_method_descriptor) + register(functools.partial, _reduce_partial) + register(socket.socket, _reduce_socket) diff --git a/billiard/resource_sharer.py b/billiard/resource_sharer.py index e10f6cdd..66076509 100644 --- a/billiard/resource_sharer.py +++ b/billiard/resource_sharer.py @@ -1,13 +1,12 @@ # -# We use a background thread for sharing fds on Unix, and for sharing -# sockets on Windows. +# We use a background thread for sharing fds on Unix, and for sharing sockets on +# Windows. # # A client which wants to pickle a resource registers it with the resource # sharer and gets an identifier in return. The unpickling process will connect # to the resource sharer, sends the identifier and its pid, and then receives # the resource. # -from __future__ import absolute_import import os import signal @@ -16,7 +15,7 @@ import threading from . import process -from . import reduction +from .context import reduction from . 
import util __all__ = ['stop'] @@ -27,10 +26,8 @@ class DupSocket(object): '''Picklable wrapper for a socket.''' - def __init__(self, sock): new_sock = sock.dup() - def send(conn, pid): share = new_sock.share(pid) conn.send_bytes(share) @@ -49,10 +46,8 @@ class DupFd(object): '''Wrapper for fd which can be used at any time.''' def __init__(self, fd): new_fd = os.dup(fd) - def send(conn, pid): reduction.send_handle(conn, new_fd, pid) - def close(): os.close(new_fd) self._id = _resource_sharer.register(send, close) @@ -64,11 +59,10 @@ def detach(self): class _ResourceSharer(object): - '''Manager for resouces using background thread.''' + '''Manager for resources using background thread.''' def __init__(self): self._key = 0 self._cache = {} - self._old_locks = [] self._lock = threading.Lock() self._listener = None self._address = None @@ -118,10 +112,7 @@ def _afterfork(self): for key, (send, close) in self._cache.items(): close() self._cache.clear() - # If self._lock was locked at the time of the fork, it may be broken - # -- see issue 6721. Replace it without letting it be gc'ed. 
- self._old_locks.append(self._lock) - self._lock = threading.Lock() + self._lock._at_fork_reinit() if self._listener is not None: self._listener.close() self._listener = None @@ -130,7 +121,7 @@ def _afterfork(self): def _start(self): from .connection import Listener - assert self._listener is None + assert self._listener is None, "Already have Listener" util.debug('starting listener and thread for sending handles') self._listener = Listener(authkey=process.current_process().authkey) self._address = self._listener.address @@ -141,7 +132,7 @@ def _start(self): def _serve(self): if hasattr(signal, 'pthread_sigmask'): - signal.pthread_sigmask(signal.SIG_BLOCK, range(1, signal.NSIG)) + signal.pthread_sigmask(signal.SIG_BLOCK, signal.valid_signals()) while 1: try: with self._listener.accept() as conn: diff --git a/billiard/resource_tracker.py b/billiard/resource_tracker.py new file mode 100644 index 00000000..c9bfa9b8 --- /dev/null +++ b/billiard/resource_tracker.py @@ -0,0 +1,231 @@ +############################################################################### +# Server process to keep track of unlinked resources (like shared memory +# segments, semaphores etc.) and clean them. +# +# On Unix we run a server process which keeps track of unlinked +# resources. The server ignores SIGINT and SIGTERM and reads from a +# pipe. Every other process of the program has a copy of the writable +# end of the pipe, so we get EOF when all other processes have exited. +# Then the server process unlinks any remaining resource names. +# +# This is important because there may be system limits for such resources: for +# instance, the system only supports a limited number of named semaphores, and +# shared-memory segments live in the RAM. If a python process leaks such a +# resource, this resource will not be removed till the next reboot. Without +# this resource tracker process, "killall python" would probably leave unlinked +# resources. 
+ +import os +import signal +import sys +import threading +import warnings + +from . import spawn +from . import util + +__all__ = ['ensure_running', 'register', 'unregister'] + +_HAVE_SIGMASK = hasattr(signal, 'pthread_sigmask') +_IGNORED_SIGNALS = (signal.SIGINT, signal.SIGTERM) + +_CLEANUP_FUNCS = { + 'noop': lambda: None, +} + +if os.name == 'posix': + import _multiprocessing + import _posixshmem + + _CLEANUP_FUNCS.update({ + 'semaphore': _multiprocessing.sem_unlink, + 'shared_memory': _posixshmem.shm_unlink, + }) + + +class ResourceTracker(object): + + def __init__(self): + self._lock = threading.Lock() + self._fd = None + self._pid = None + + def _stop(self): + with self._lock: + if self._fd is None: + # not running + return + + # closing the "alive" file descriptor stops main() + os.close(self._fd) + self._fd = None + + os.waitpid(self._pid, 0) + self._pid = None + + def getfd(self): + self.ensure_running() + return self._fd + + def ensure_running(self): + '''Make sure that resource tracker process is running. + + This can be run from any process. Usually a child process will use + the resource created by its parent.''' + with self._lock: + if self._fd is not None: + # resource tracker was launched before, is it still running? + if self._check_alive(): + # => still alive + return + # => dead, launch it again + os.close(self._fd) + + # Clean-up to avoid dangling processes. + try: + # _pid can be None if this process is a child from another + # python process, which has started the resource_tracker. + if self._pid is not None: + os.waitpid(self._pid, 0) + except ChildProcessError: + # The resource_tracker has already been terminated. + pass + self._fd = None + self._pid = None + + warnings.warn('resource_tracker: process died unexpectedly, ' + 'relaunching. 
Some resources might leak.') + + fds_to_pass = [] + try: + fds_to_pass.append(sys.stderr.fileno()) + except Exception: + pass + cmd = 'from multiprocessing.resource_tracker import main;main(%d)' + r, w = os.pipe() + try: + fds_to_pass.append(r) + # process will out live us, so no need to wait on pid + exe = spawn.get_executable() + args = [exe] + util._args_from_interpreter_flags() + args += ['-c', cmd % r] + # bpo-33613: Register a signal mask that will block the signals. + # This signal mask will be inherited by the child that is going + # to be spawned and will protect the child from a race condition + # that can make the child die before it registers signal handlers + # for SIGINT and SIGTERM. The mask is unregistered after spawning + # the child. + try: + if _HAVE_SIGMASK: + signal.pthread_sigmask(signal.SIG_BLOCK, _IGNORED_SIGNALS) + pid = util.spawnv_passfds(exe, args, fds_to_pass) + finally: + if _HAVE_SIGMASK: + signal.pthread_sigmask(signal.SIG_UNBLOCK, _IGNORED_SIGNALS) + except: + os.close(w) + raise + else: + self._fd = w + self._pid = pid + finally: + os.close(r) + + def _check_alive(self): + '''Check that the pipe has not been closed by sending a probe.''' + try: + # We cannot use send here as it calls ensure_running, creating + # a cycle. 
+ os.write(self._fd, b'PROBE:0:noop\n') + except OSError: + return False + else: + return True + + def register(self, name, rtype): + '''Register name of resource with resource tracker.''' + self._send('REGISTER', name, rtype) + + def unregister(self, name, rtype): + '''Unregister name of resource with resource tracker.''' + self._send('UNREGISTER', name, rtype) + + def _send(self, cmd, name, rtype): + self.ensure_running() + msg = '{0}:{1}:{2}\n'.format(cmd, name, rtype).encode('ascii') + if len(name) > 512: + # posix guarantees that writes to a pipe of less than PIPE_BUF + # bytes are atomic, and that PIPE_BUF >= 512 + raise ValueError('name too long') + nbytes = os.write(self._fd, msg) + assert nbytes == len(msg), "nbytes {0:n} but len(msg) {1:n}".format( + nbytes, len(msg)) + + +_resource_tracker = ResourceTracker() +ensure_running = _resource_tracker.ensure_running +register = _resource_tracker.register +unregister = _resource_tracker.unregister +getfd = _resource_tracker.getfd + +def main(fd): + '''Run resource tracker.''' + # protect the process from ^C and "killall python" etc + signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGTERM, signal.SIG_IGN) + if _HAVE_SIGMASK: + signal.pthread_sigmask(signal.SIG_UNBLOCK, _IGNORED_SIGNALS) + + for f in (sys.stdin, sys.stdout): + try: + f.close() + except Exception: + pass + + cache = {rtype: set() for rtype in _CLEANUP_FUNCS.keys()} + try: + # keep track of registered/unregistered resources + with open(fd, 'rb') as f: + for line in f: + try: + cmd, name, rtype = line.strip().decode('ascii').split(':') + cleanup_func = _CLEANUP_FUNCS.get(rtype, None) + if cleanup_func is None: + raise ValueError( + f'Cannot register {name} for automatic cleanup: ' + f'unknown resource type {rtype}') + + if cmd == 'REGISTER': + cache[rtype].add(name) + elif cmd == 'UNREGISTER': + cache[rtype].remove(name) + elif cmd == 'PROBE': + pass + else: + raise RuntimeError('unrecognized command %r' % cmd) + except 
Exception: + try: + sys.excepthook(*sys.exc_info()) + except: + pass + finally: + # all processes have terminated; cleanup any remaining resources + for rtype, rtype_cache in cache.items(): + if rtype_cache: + try: + warnings.warn('resource_tracker: There appear to be %d ' + 'leaked %s objects to clean up at shutdown' % + (len(rtype_cache), rtype)) + except Exception: + pass + for name in rtype_cache: + # For some reason the process which created and registered this + # resource has failed to unregister it. Presumably it has + # died. We therefore unlink it. + try: + try: + _CLEANUP_FUNCS[rtype](name) + except Exception as e: + warnings.warn('resource_tracker: %r: %s' % (name, e)) + finally: + pass diff --git a/billiard/shared_memory.py b/billiard/shared_memory.py new file mode 100644 index 00000000..a3a5fcf4 --- /dev/null +++ b/billiard/shared_memory.py @@ -0,0 +1,530 @@ +"""Provides shared memory for direct access across processes. + +The API of this package is currently provisional. Refer to the +documentation for details. +""" + + +__all__ = [ 'SharedMemory', 'ShareableList' ] + + +from functools import partial +import mmap +import os +import errno +import struct +import secrets +import types + +if os.name == "nt": + import _winapi + _USE_POSIX = False +else: + import _posixshmem + _USE_POSIX = True + + +_O_CREX = os.O_CREAT | os.O_EXCL + +# FreeBSD (and perhaps other BSDs) limit names to 14 characters. +_SHM_SAFE_NAME_LENGTH = 14 + +# Shared memory block name prefix +if _USE_POSIX: + _SHM_NAME_PREFIX = '/psm_' +else: + _SHM_NAME_PREFIX = 'wnsm_' + + +def _make_filename(): + "Create a random filename for the shared memory object." 
+ # number of random bytes to use for name + nbytes = (_SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)) // 2 + assert nbytes >= 2, '_SHM_NAME_PREFIX too long' + name = _SHM_NAME_PREFIX + secrets.token_hex(nbytes) + assert len(name) <= _SHM_SAFE_NAME_LENGTH + return name + + +class SharedMemory: + """Creates a new shared memory block or attaches to an existing + shared memory block. + + Every shared memory block is assigned a unique name. This enables + one process to create a shared memory block with a particular name + so that a different process can attach to that same shared memory + block using that same name. + + As a resource for sharing data across processes, shared memory blocks + may outlive the original process that created them. When one process + no longer needs access to a shared memory block that might still be + needed by other processes, the close() method should be called. + When a shared memory block is no longer needed by any process, the + unlink() method should be called to ensure proper cleanup.""" + + # Defaults; enables close() and unlink() to run without errors. 
+ _name = None + _fd = -1 + _mmap = None + _buf = None + _flags = os.O_RDWR + _mode = 0o600 + _prepend_leading_slash = True if _USE_POSIX else False + + def __init__(self, name=None, create=False, size=0): + if not size >= 0: + raise ValueError("'size' must be a positive integer") + if create: + self._flags = _O_CREX | os.O_RDWR + if name is None and not self._flags & os.O_EXCL: + raise ValueError("'name' can only be None if create=True") + + if _USE_POSIX: + + # POSIX Shared Memory + + if name is None: + while True: + name = _make_filename() + try: + self._fd = _posixshmem.shm_open( + name, + self._flags, + mode=self._mode + ) + except FileExistsError: + continue + self._name = name + break + else: + name = "/" + name if self._prepend_leading_slash else name + self._fd = _posixshmem.shm_open( + name, + self._flags, + mode=self._mode + ) + self._name = name + try: + if create and size: + os.ftruncate(self._fd, size) + stats = os.fstat(self._fd) + size = stats.st_size + self._mmap = mmap.mmap(self._fd, size) + except OSError: + self.unlink() + raise + + from .resource_tracker import register + register(self._name, "shared_memory") + + else: + + # Windows Named Shared Memory + + if create: + while True: + temp_name = _make_filename() if name is None else name + # Create and reserve shared memory block with this name + # until it can be attached to by mmap. 
+ h_map = _winapi.CreateFileMapping( + _winapi.INVALID_HANDLE_VALUE, + _winapi.NULL, + _winapi.PAGE_READWRITE, + (size >> 32) & 0xFFFFFFFF, + size & 0xFFFFFFFF, + temp_name + ) + try: + last_error_code = _winapi.GetLastError() + if last_error_code == _winapi.ERROR_ALREADY_EXISTS: + if name is not None: + raise FileExistsError( + errno.EEXIST, + os.strerror(errno.EEXIST), + name, + _winapi.ERROR_ALREADY_EXISTS + ) + else: + continue + self._mmap = mmap.mmap(-1, size, tagname=temp_name) + finally: + _winapi.CloseHandle(h_map) + self._name = temp_name + break + + else: + self._name = name + # Dynamically determine the existing named shared memory + # block's size which is likely a multiple of mmap.PAGESIZE. + h_map = _winapi.OpenFileMapping( + _winapi.FILE_MAP_READ, + False, + name + ) + try: + p_buf = _winapi.MapViewOfFile( + h_map, + _winapi.FILE_MAP_READ, + 0, + 0, + 0 + ) + finally: + _winapi.CloseHandle(h_map) + size = _winapi.VirtualQuerySize(p_buf) + self._mmap = mmap.mmap(-1, size, tagname=name) + + self._size = size + self._buf = memoryview(self._mmap) + + def __del__(self): + try: + self.close() + except OSError: + pass + + def __reduce__(self): + return ( + self.__class__, + ( + self.name, + False, + self.size, + ), + ) + + def __repr__(self): + return f'{self.__class__.__name__}({self.name!r}, size={self.size})' + + @property + def buf(self): + "A memoryview of contents of the shared memory block." + return self._buf + + @property + def name(self): + "Unique name that identifies the shared memory block." + reported_name = self._name + if _USE_POSIX and self._prepend_leading_slash: + if self._name.startswith("/"): + reported_name = self._name[1:] + return reported_name + + @property + def size(self): + "Size in bytes." 
+ return self._size + + def close(self): + """Closes access to the shared memory from this instance but does + not destroy the shared memory block.""" + if self._buf is not None: + self._buf.release() + self._buf = None + if self._mmap is not None: + self._mmap.close() + self._mmap = None + if _USE_POSIX and self._fd >= 0: + os.close(self._fd) + self._fd = -1 + + def unlink(self): + """Requests that the underlying shared memory block be destroyed. + + In order to ensure proper cleanup of resources, unlink should be + called once (and only once) across all processes which have access + to the shared memory block.""" + if _USE_POSIX and self._name: + from .resource_tracker import unregister + _posixshmem.shm_unlink(self._name) + unregister(self._name, "shared_memory") + + +_encoding = "utf8" + +class ShareableList: + """Pattern for a mutable list-like object shareable via a shared + memory block. It differs from the built-in list type in that these + lists can not change their overall length (i.e. no append, insert, + etc.) 
+ + Because values are packed into a memoryview as bytes, the struct + packing format for any storable value must require no more than 8 + characters to describe its format.""" + + # The shared memory area is organized as follows: + # - 8 bytes: number of items (N) as a 64-bit integer + # - (N + 1) * 8 bytes: offsets of each element from the start of the + # data area + # - K bytes: the data area storing item values (with encoding and size + # depending on their respective types) + # - N * 8 bytes: `struct` format string for each element + # - N bytes: index into _back_transforms_mapping for each element + # (for reconstructing the corresponding Python value) + _types_mapping = { + int: "q", + float: "d", + bool: "xxxxxxx?", + str: "%ds", + bytes: "%ds", + None.__class__: "xxxxxx?x", + } + _alignment = 8 + _back_transforms_mapping = { + 0: lambda value: value, # int, float, bool + 1: lambda value: value.rstrip(b'\x00').decode(_encoding), # str + 2: lambda value: value.rstrip(b'\x00'), # bytes + 3: lambda _value: None, # None + } + + @staticmethod + def _extract_recreation_code(value): + """Used in concert with _back_transforms_mapping to convert values + into the appropriate Python objects when retrieving them from + the list as well as when storing them.""" + if not isinstance(value, (str, bytes, None.__class__)): + return 0 + elif isinstance(value, str): + return 1 + elif isinstance(value, bytes): + return 2 + else: + return 3 # NoneType + + def __init__(self, sequence=None, *, name=None): + if name is None or sequence is not None: + sequence = sequence or () + _formats = [ + self._types_mapping[type(item)] + if not isinstance(item, (str, bytes)) + else self._types_mapping[type(item)] % ( + self._alignment * (len(item) // self._alignment + 1), + ) + for item in sequence + ] + self._list_len = len(_formats) + assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len + offset = 0 + # The offsets of each list element into the shared memory's + # data area (0 
meaning the start of the data area, not the start + # of the shared memory area). + self._allocated_offsets = [0] + for fmt in _formats: + offset += self._alignment if fmt[-1] != "s" else int(fmt[:-1]) + self._allocated_offsets.append(offset) + _recreation_codes = [ + self._extract_recreation_code(item) for item in sequence + ] + requested_size = struct.calcsize( + "q" + self._format_size_metainfo + + "".join(_formats) + + self._format_packing_metainfo + + self._format_back_transform_codes + ) + + self.shm = SharedMemory(name, create=True, size=requested_size) + else: + self.shm = SharedMemory(name) + + if sequence is not None: + _enc = _encoding + struct.pack_into( + "q" + self._format_size_metainfo, + self.shm.buf, + 0, + self._list_len, + *(self._allocated_offsets) + ) + struct.pack_into( + "".join(_formats), + self.shm.buf, + self._offset_data_start, + *(v.encode(_enc) if isinstance(v, str) else v for v in sequence) + ) + struct.pack_into( + self._format_packing_metainfo, + self.shm.buf, + self._offset_packing_formats, + *(v.encode(_enc) for v in _formats) + ) + struct.pack_into( + self._format_back_transform_codes, + self.shm.buf, + self._offset_back_transform_codes, + *(_recreation_codes) + ) + + else: + self._list_len = len(self) # Obtains size from offset 0 in buffer. + self._allocated_offsets = list( + struct.unpack_from( + self._format_size_metainfo, + self.shm.buf, + 1 * 8 + ) + ) + + def _get_packing_format(self, position): + "Gets the packing format for a single value stored in the list." 
+ position = position if position >= 0 else position + self._list_len + if (position >= self._list_len) or (self._list_len < 0): + raise IndexError("Requested position out of range.") + + v = struct.unpack_from( + "8s", + self.shm.buf, + self._offset_packing_formats + position * 8 + )[0] + fmt = v.rstrip(b'\x00') + fmt_as_str = fmt.decode(_encoding) + + return fmt_as_str + + def _get_back_transform(self, position): + "Gets the back transformation function for a single value." + + if (position >= self._list_len) or (self._list_len < 0): + raise IndexError("Requested position out of range.") + + transform_code = struct.unpack_from( + "b", + self.shm.buf, + self._offset_back_transform_codes + position + )[0] + transform_function = self._back_transforms_mapping[transform_code] + + return transform_function + + def _set_packing_format_and_transform(self, position, fmt_as_str, value): + """Sets the packing format and back transformation code for a + single value in the list at the specified position.""" + + if (position >= self._list_len) or (self._list_len < 0): + raise IndexError("Requested position out of range.") + + struct.pack_into( + "8s", + self.shm.buf, + self._offset_packing_formats + position * 8, + fmt_as_str.encode(_encoding) + ) + + transform_code = self._extract_recreation_code(value) + struct.pack_into( + "b", + self.shm.buf, + self._offset_back_transform_codes + position, + transform_code + ) + + def __getitem__(self, position): + position = position if position >= 0 else position + self._list_len + try: + offset = self._offset_data_start + self._allocated_offsets[position] + (v,) = struct.unpack_from( + self._get_packing_format(position), + self.shm.buf, + offset + ) + except IndexError: + raise IndexError("index out of range") + + back_transform = self._get_back_transform(position) + v = back_transform(v) + + return v + + def __setitem__(self, position, value): + position = position if position >= 0 else position + self._list_len + try: + item_offset = 
self._allocated_offsets[position] + offset = self._offset_data_start + item_offset + current_format = self._get_packing_format(position) + except IndexError: + raise IndexError("assignment index out of range") + + if not isinstance(value, (str, bytes)): + new_format = self._types_mapping[type(value)] + encoded_value = value + else: + allocated_length = self._allocated_offsets[position + 1] - item_offset + + encoded_value = (value.encode(_encoding) + if isinstance(value, str) else value) + if len(encoded_value) > allocated_length: + raise ValueError("bytes/str item exceeds available storage") + if current_format[-1] == "s": + new_format = current_format + else: + new_format = self._types_mapping[str] % ( + allocated_length, + ) + + self._set_packing_format_and_transform( + position, + new_format, + value + ) + struct.pack_into(new_format, self.shm.buf, offset, encoded_value) + + def __reduce__(self): + return partial(self.__class__, name=self.shm.name), () + + def __len__(self): + return struct.unpack_from("q", self.shm.buf, 0)[0] + + def __repr__(self): + return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})' + + @property + def format(self): + "The struct packing format used by all currently stored items." + return "".join( + self._get_packing_format(i) for i in range(self._list_len) + ) + + @property + def _format_size_metainfo(self): + "The struct packing format used for the items' storage offsets." + return "q" * (self._list_len + 1) + + @property + def _format_packing_metainfo(self): + "The struct packing format used for the items' packing formats." + return "8s" * self._list_len + + @property + def _format_back_transform_codes(self): + "The struct packing format used for the items' back transforms." 
+ return "b" * self._list_len + + @property + def _offset_data_start(self): + # - 8 bytes for the list length + # - (N + 1) * 8 bytes for the element offsets + return (self._list_len + 2) * 8 + + @property + def _offset_packing_formats(self): + return self._offset_data_start + self._allocated_offsets[-1] + + @property + def _offset_back_transform_codes(self): + return self._offset_packing_formats + self._list_len * 8 + + def count(self, value): + "L.count(value) -> integer -- return number of occurrences of value." + + return sum(value == entry for entry in self) + + def index(self, value): + """L.index(value) -> integer -- return first index of value. + Raises ValueError if the value is not present.""" + + for position, entry in enumerate(self): + if value == entry: + return position + else: + raise ValueError(f"{value!r} not in this container") + + __class_getitem__ = classmethod(types.GenericAlias) diff --git a/billiard/sharedctypes.py b/billiard/sharedctypes.py index 97675df4..60717070 100644 --- a/billiard/sharedctypes.py +++ b/billiard/sharedctypes.py @@ -6,38 +6,41 @@ # Copyright (c) 2006-2008, R Oudkerk # Licensed to PSF under a Contributor Agreement. # -from __future__ import absolute_import import ctypes -import sys import weakref from . import heap from . 
import get_context -from .context import assert_spawning -from .five import int_types -from .reduction import ForkingPickler + +from .context import reduction, assert_spawning +_ForkingPickler = reduction.ForkingPickler __all__ = ['RawValue', 'RawArray', 'Value', 'Array', 'copy', 'synchronized'] -PY3 = sys.version_info[0] == 3 +# +# +# typecode_to_type = { - 'c': ctypes.c_char, 'u': ctypes.c_wchar, - 'b': ctypes.c_byte, 'B': ctypes.c_ubyte, - 'h': ctypes.c_short, 'H': ctypes.c_ushort, - 'i': ctypes.c_int, 'I': ctypes.c_uint, - 'l': ctypes.c_long, 'L': ctypes.c_ulong, - 'f': ctypes.c_float, 'd': ctypes.c_double -} + 'c': ctypes.c_char, 'u': ctypes.c_wchar, + 'b': ctypes.c_byte, 'B': ctypes.c_ubyte, + 'h': ctypes.c_short, 'H': ctypes.c_ushort, + 'i': ctypes.c_int, 'I': ctypes.c_uint, + 'l': ctypes.c_long, 'L': ctypes.c_ulong, + 'q': ctypes.c_longlong, 'Q': ctypes.c_ulonglong, + 'f': ctypes.c_float, 'd': ctypes.c_double + } +# +# +# def _new_value(type_): size = ctypes.sizeof(type_) wrapper = heap.BufferWrapper(size) return rebuild_ctype(type_, wrapper, None) - def RawValue(typecode_or_type, *args): ''' Returns a ctypes object allocated from shared memory @@ -48,13 +51,12 @@ def RawValue(typecode_or_type, *args): obj.__init__(*args) return obj - def RawArray(typecode_or_type, size_or_initializer): ''' Returns a ctypes array allocated from shared memory ''' type_ = typecode_to_type.get(typecode_or_type, typecode_or_type) - if isinstance(size_or_initializer, int_types): + if isinstance(size_or_initializer, int): type_ = type_ * size_or_initializer obj = _new_value(type_) ctypes.memset(ctypes.addressof(obj), 0, ctypes.sizeof(obj)) @@ -65,16 +67,10 @@ def RawArray(typecode_or_type, size_or_initializer): result.__init__(*size_or_initializer) return result - -def Value(typecode_or_type, *args, **kwds): +def Value(typecode_or_type, *args, lock=True, ctx=None): ''' Return a synchronization wrapper for a Value ''' - lock = kwds.pop('lock', None) - ctx = kwds.pop('ctx', None) - 
if kwds: - raise ValueError( - 'unrecognized keyword argument(s): %s' % list(kwds.keys())) obj = RawValue(typecode_or_type, *args) if lock is False: return obj @@ -82,19 +78,13 @@ def Value(typecode_or_type, *args, **kwds): ctx = ctx or get_context() lock = ctx.RLock() if not hasattr(lock, 'acquire'): - raise AttributeError("'%r' has no method 'acquire'" % lock) + raise AttributeError("%r has no method 'acquire'" % lock) return synchronized(obj, lock, ctx=ctx) - -def Array(typecode_or_type, size_or_initializer, **kwds): +def Array(typecode_or_type, size_or_initializer, *, lock=True, ctx=None): ''' Return a synchronization wrapper for a RawArray ''' - lock = kwds.pop('lock', None) - ctx = kwds.pop('ctx', None) - if kwds: - raise ValueError( - 'unrecognized keyword argument(s): %s' % list(kwds.keys())) obj = RawArray(typecode_or_type, size_or_initializer) if lock is False: return obj @@ -102,16 +92,14 @@ def Array(typecode_or_type, size_or_initializer, **kwds): ctx = ctx or get_context() lock = ctx.RLock() if not hasattr(lock, 'acquire'): - raise AttributeError("'%r' has no method 'acquire'" % lock) + raise AttributeError("%r has no method 'acquire'" % lock) return synchronized(obj, lock, ctx=ctx) - def copy(obj): new_obj = _new_value(type(obj)) ctypes.pointer(new_obj)[0] = obj return new_obj - def synchronized(obj, lock=None, ctx=None): assert not isinstance(obj, SynchronizedBase), 'object already synchronized' ctx = ctx or get_context() @@ -128,7 +116,7 @@ def synchronized(obj, lock=None, ctx=None): scls = class_cache[cls] except KeyError: names = [field[0] for field in cls._fields_] - d = dict((name, make_property(name)) for name in names) + d = {name: make_property(name) for name in names} classname = 'Synchronized' + cls.__name__ scls = class_cache[cls] = type(classname, (SynchronizedBase,), d) return scls(obj, lock, ctx) @@ -137,7 +125,6 @@ def synchronized(obj, lock=None, ctx=None): # Functions for pickling/unpickling # - def reduce_ctype(obj): 
assert_spawning(obj) if isinstance(obj, ctypes.Array): @@ -145,16 +132,12 @@ def reduce_ctype(obj): else: return rebuild_ctype, (type(obj), obj._wrapper, None) - def rebuild_ctype(type_, wrapper, length): if length is not None: type_ = type_ * length - ForkingPickler.register(type_, reduce_ctype) - if PY3: - buf = wrapper.create_memoryview() - obj = type_.from_buffer(buf) - else: - obj = type_.from_address(wrapper.get_address()) + _ForkingPickler.register(type_, reduce_ctype) + buf = wrapper.create_memoryview() + obj = type_.from_buffer(buf) obj._wrapper = wrapper return obj @@ -162,17 +145,15 @@ def rebuild_ctype(type_, wrapper, length): # Function to create properties # - def make_property(name): try: return prop_cache[name] except KeyError: d = {} - exec(template % ((name, ) * 7), d) + exec(template % ((name,)*7), d) prop_cache[name] = d[name] return d[name] - template = ''' def get%s(self): self.acquire() @@ -196,7 +177,6 @@ def set%s(self, value): # Synchronized wrappers # - class SynchronizedBase(object): def __init__(self, obj, lock=None, ctx=None): diff --git a/billiard/spawn.py b/billiard/spawn.py index 9d04b234..7cc129e2 100644 --- a/billiard/spawn.py +++ b/billiard/spawn.py @@ -7,30 +7,20 @@ # Copyright (c) 2006-2008, R Oudkerk # Licensed to PSF under a Contributor Agreement. # -from __future__ import absolute_import -import io import os -import pickle import sys import runpy import types -import warnings from . import get_start_method, set_start_method from . import process +from .context import reduction from . import util __all__ = ['_main', 'freeze_support', 'set_executable', 'get_executable', 'get_preparation_data', 'get_command_line', 'import_main_path'] -W_OLD_DJANGO_LAYOUT = """\ -Will add directory %r to path! This is necessary to accommodate \ -pre-Django 1.4 layouts using setup_environ. -You can skip this warning by adding a DJANGO_SETTINGS_MODULE=settings \ -environment variable. 
-""" - # # _python_exe is the assumed path to the python executable. # People embedding Python want to modify it. @@ -40,7 +30,7 @@ WINEXE = False WINSERVICE = False else: - WINEXE = (sys.platform == 'win32' and getattr(sys, 'frozen', False)) + WINEXE = getattr(sys, 'frozen', False) WINSERVICE = sys.executable.lower().endswith("pythonservice.exe") if WINSERVICE: @@ -48,58 +38,10 @@ else: _python_exe = sys.executable - -def _module_parent_dir(mod): - dir, filename = os.path.split(_module_dir(mod)) - if dir == os.curdir or not dir: - dir = os.getcwd() - return dir - - -def _module_dir(mod): - if '__init__.py' in mod.__file__: - return os.path.dirname(mod.__file__) - return mod.__file__ - - -def _Django_old_layout_hack__save(): - if 'DJANGO_PROJECT_DIR' not in os.environ: - try: - settings_name = os.environ['DJANGO_SETTINGS_MODULE'] - except KeyError: - return # not using Django. - - conf_settings = sys.modules.get('django.conf.settings') - configured = conf_settings and conf_settings.configured - try: - project_name, _ = settings_name.split('.', 1) - except ValueError: - return # not modified by setup_environ - - project = __import__(project_name) - try: - project_dir = os.path.normpath(_module_parent_dir(project)) - except AttributeError: - return # dynamically generated module (no __file__) - if configured: - warnings.warn(UserWarning( - W_OLD_DJANGO_LAYOUT % os.path.realpath(project_dir) - )) - os.environ['DJANGO_PROJECT_DIR'] = project_dir - - -def _Django_old_layout_hack__load(): - try: - sys.path.append(os.environ['DJANGO_PROJECT_DIR']) - except KeyError: - pass - - def set_executable(exe): global _python_exe _python_exe = exe - def get_executable(): return _python_exe @@ -107,12 +49,11 @@ def get_executable(): # # - def is_forking(argv): ''' Return whether commandline indicates we are forking ''' - if len(argv) >= 2 and argv[1] == '--billiard-fork': + if len(argv) >= 2 and argv[1] == '--multiprocessing-fork': return True else: return False @@ -139,75 +80,53 @@ 
def get_command_line(**kwds): Returns prefix of command line used for spawning a child process ''' if getattr(sys, 'frozen', False): - return ([sys.executable, '--billiard-fork'] + + return ([sys.executable, '--multiprocessing-fork'] + ['%s=%r' % item for item in kwds.items()]) else: - prog = 'from billiard.spawn import spawn_main; spawn_main(%s)' + prog = 'from multiprocessing.spawn import spawn_main; spawn_main(%s)' prog %= ', '.join('%s=%r' % item for item in kwds.items()) opts = util._args_from_interpreter_flags() - return [_python_exe] + opts + ['-c', prog, '--billiard-fork'] + return [_python_exe] + opts + ['-c', prog, '--multiprocessing-fork'] def spawn_main(pipe_handle, parent_pid=None, tracker_fd=None): ''' Run code specified by data received over pipe ''' - assert is_forking(sys.argv) + assert is_forking(sys.argv), "Not forking" if sys.platform == 'win32': import msvcrt - from .reduction import steal_handle - new_handle = steal_handle(parent_pid, pipe_handle) + import _winapi + + if parent_pid is not None: + source_process = _winapi.OpenProcess( + _winapi.SYNCHRONIZE | _winapi.PROCESS_DUP_HANDLE, + False, parent_pid) + else: + source_process = None + new_handle = reduction.duplicate(pipe_handle, + source_process=source_process) fd = msvcrt.open_osfhandle(new_handle, os.O_RDONLY) + parent_sentinel = source_process else: - from . import semaphore_tracker - semaphore_tracker._semaphore_tracker._fd = tracker_fd + from . import resource_tracker + resource_tracker._resource_tracker._fd = tracker_fd fd = pipe_handle - exitcode = _main(fd) + parent_sentinel = os.dup(pipe_handle) + exitcode = _main(fd, parent_sentinel) sys.exit(exitcode) -def _setup_logging_in_child_hack(): - # Huge hack to make logging before Process.run work. 
- try: - os.environ["MP_MAIN_FILE"] = sys.modules["__main__"].__file__ - except KeyError: - pass - except AttributeError: - pass - loglevel = os.environ.get("_MP_FORK_LOGLEVEL_") - logfile = os.environ.get("_MP_FORK_LOGFILE_") or None - format = os.environ.get("_MP_FORK_LOGFORMAT_") - if loglevel: - from . import util - import logging - logger = util.get_logger() - logger.setLevel(int(loglevel)) - if not logger.handlers: - logger._rudimentary_setup = True - logfile = logfile or sys.__stderr__ - if hasattr(logfile, "write"): - handler = logging.StreamHandler(logfile) - else: - handler = logging.FileHandler(logfile) - formatter = logging.Formatter( - format or util.DEFAULT_LOGGING_FORMAT, - ) - handler.setFormatter(formatter) - logger.addHandler(handler) - - -def _main(fd): - _Django_old_layout_hack__load() - with io.open(fd, 'rb', closefd=True) as from_parent: +def _main(fd, parent_sentinel): + with os.fdopen(fd, 'rb', closefd=True) as from_parent: process.current_process()._inheriting = True try: - preparation_data = pickle.load(from_parent) + preparation_data = reduction.pickle.load(from_parent) prepare(preparation_data) - _setup_logging_in_child_hack() - self = pickle.load(from_parent) + self = reduction.pickle.load(from_parent) finally: del process.current_process()._inheriting - return self._bootstrap() + return self._bootstrap(parent_sentinel) def _check_not_importing_main(): @@ -236,12 +155,12 @@ def get_preparation_data(name): d = dict( log_to_stderr=util._log_to_stderr, authkey=process.current_process().authkey, - ) + ) if util._logger is not None: d['log_level'] = util._logger.getEffectiveLevel() - sys_path = sys.path[:] + sys_path=sys.path.copy() try: i = sys_path.index('') except ValueError: @@ -256,22 +175,19 @@ def get_preparation_data(name): orig_dir=process.ORIGINAL_DIR, dir=os.getcwd(), start_method=get_start_method(), - ) + ) # Figure out whether to initialise main in the subprocess as a module # or through direct execution (or to leave it alone 
entirely) main_module = sys.modules['__main__'] - try: - main_mod_name = main_module.__spec__.name - except AttributeError: - main_mod_name = main_module.__name__ + main_mod_name = getattr(main_module.__spec__, "name", None) if main_mod_name is not None: d['init_main_from_name'] = main_mod_name elif sys.platform != 'win32' or (not WINEXE and not WINSERVICE): main_path = getattr(main_module, '__file__', None) if main_path is not None: if (not os.path.isabs(main_path) and - process.ORIGINAL_DIR is not None): + process.ORIGINAL_DIR is not None): main_path = os.path.join(process.ORIGINAL_DIR, main_path) d['init_main_from_path'] = os.path.normpath(main_path) @@ -281,10 +197,8 @@ def get_preparation_data(name): # Prepare current process # - old_main_modules = [] - def prepare(data): ''' Try to get current process ready to unpickle process object @@ -314,7 +228,7 @@ def prepare(data): process.ORIGINAL_DIR = data['orig_dir'] if 'start_method' in data: - set_start_method(data['start_method']) + set_start_method(data['start_method'], force=True) if 'init_main_from_name' in data: _fixup_main_from_name(data['init_main_from_name']) @@ -323,8 +237,6 @@ def prepare(data): # Multiprocessing module helpers to fix up the main module in # spawned subprocesses - - def _fixup_main_from_name(mod_name): # __main__.py files for packages, directories, zip archives, etc, run # their "main only" code unconditionally, so we don't even try to diff --git a/billiard/synchronize.py b/billiard/synchronize.py index b97fbfb5..4fcbefc8 100644 --- a/billiard/synchronize.py +++ b/billiard/synchronize.py @@ -6,107 +6,86 @@ # Copyright (c) 2006-2008, R Oudkerk # Licensed to PSF under a Contributor Agreement. # -from __future__ import absolute_import -import errno +__all__ = [ + 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', 'Event' + ] + +import threading import sys import tempfile -import threading +import _multiprocessing +import time from . import context from . import process from . 
import util -from ._ext import _billiard, ensure_SemLock -from .five import range, monotonic - -__all__ = [ - 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', 'Event', -] - # Try to import the mp.synchronize module cleanly, if it fails # raise ImportError for platforms lacking a working sem_open implementation. # See issue 3770 -ensure_SemLock() +try: + from _multiprocessing import SemLock, sem_unlink +except (ImportError): + raise ImportError("This platform lacks a functioning sem_open" + + " implementation, therefore, the required" + + " synchronization primitives needed will not" + + " function, see issue 3770.") # # Constants # RECURSIVE_MUTEX, SEMAPHORE = list(range(2)) -SEM_VALUE_MAX = _billiard.SemLock.SEM_VALUE_MAX - -try: - sem_unlink = _billiard.SemLock.sem_unlink -except AttributeError: # pragma: no cover - try: - # Py3.4+ implements sem_unlink and the semaphore must be named - from _multiprocessing import sem_unlink # noqa - except ImportError: - sem_unlink = None # noqa +SEM_VALUE_MAX = _multiprocessing.SemLock.SEM_VALUE_MAX # -# Base class for semaphores and mutexes; wraps `_billiard.SemLock` +# Base class for semaphores and mutexes; wraps `_multiprocessing.SemLock` # - -def _semname(sl): - try: - return sl.name - except AttributeError: - pass - - class SemLock(object): + _rand = tempfile._RandomNameSequence() - def __init__(self, kind, value, maxvalue, ctx=None): + def __init__(self, kind, value, maxvalue, *, ctx): if ctx is None: ctx = context._default_context.get_context() name = ctx.get_start_method() unlink_now = sys.platform == 'win32' or name == 'fork' - if sem_unlink: - for i in range(100): - try: - sl = self._semlock = _billiard.SemLock( - kind, value, maxvalue, self._make_name(), unlink_now, - ) - except (OSError, IOError) as exc: - if getattr(exc, 'errno', None) != errno.EEXIST: - raise - else: - break + for i in range(100): + try: + sl = self._semlock = _multiprocessing.SemLock( + kind, value, maxvalue, self._make_name(), + 
unlink_now) + except FileExistsError: + pass else: - exc = IOError('cannot find file for semaphore') - exc.errno = errno.EEXIST - raise exc + break else: - sl = self._semlock = _billiard.SemLock(kind, value, maxvalue) + raise FileExistsError('cannot find name for semaphore') - util.debug('created semlock with handle %s', sl.handle) + util.debug('created semlock with handle %s' % sl.handle) self._make_methods() - if sem_unlink: + if sys.platform != 'win32': + def _after_fork(obj): + obj._semlock._after_fork() + util.register_after_fork(self, _after_fork) - if sys.platform != 'win32': - def _after_fork(obj): - obj._semlock._after_fork() - util.register_after_fork(self, _after_fork) - - if _semname(self._semlock) is not None: - # We only get here if we are on Unix with forking - # disabled. When the object is garbage collected or the - # process shuts down we unlink the semaphore name - from .semaphore_tracker import register - register(self._semlock.name) - util.Finalize(self, SemLock._cleanup, (self._semlock.name,), - exitpriority=0) + if self._semlock.name is not None: + # We only get here if we are on Unix with forking + # disabled. 
When the object is garbage collected or the + # process shuts down we unlink the semaphore name + from .resource_tracker import register + register(self._semlock.name, "semaphore") + util.Finalize(self, SemLock._cleanup, (self._semlock.name,), + exitpriority=0) @staticmethod def _cleanup(name): - from .semaphore_tracker import unregister + from .resource_tracker import unregister sem_unlink(name) - unregister(name) + unregister(name, "semaphore") def _make_methods(self): self.acquire = self._semlock.acquire @@ -125,16 +104,11 @@ def __getstate__(self): h = context.get_spawning_popen().duplicate_for_child(sl.handle) else: h = sl.handle - state = (h, sl.kind, sl.maxvalue) - try: - state += (sl.name, ) - except AttributeError: - pass - return state + return (h, sl.kind, sl.maxvalue, sl.name) def __setstate__(self, state): - self._semlock = _billiard.SemLock._rebuild(*state) - util.debug('recreated blocker with handle %r', state[0]) + self._semlock = _multiprocessing.SemLock._rebuild(*state) + util.debug('recreated blocker with handle %r' % state[0]) self._make_methods() @staticmethod @@ -142,10 +116,13 @@ def _make_name(): return '%s-%s' % (process.current_process()._config['semprefix'], next(SemLock._rand)) +# +# Semaphore +# class Semaphore(SemLock): - def __init__(self, value=1, ctx=None): + def __init__(self, value=1, *, ctx): SemLock.__init__(self, SEMAPHORE, value, SEM_VALUE_MAX, ctx=ctx) def get_value(self): @@ -158,10 +135,13 @@ def __repr__(self): value = 'unknown' return '<%s(value=%s)>' % (self.__class__.__name__, value) +# +# Bounded semaphore +# class BoundedSemaphore(Semaphore): - def __init__(self, value=1, ctx=None): + def __init__(self, value=1, *, ctx): SemLock.__init__(self, SEMAPHORE, value, value, ctx=ctx) def __repr__(self): @@ -169,16 +149,16 @@ def __repr__(self): value = self._semlock._get_value() except Exception: value = 'unknown' - return '<%s(value=%s, maxvalue=%s)>' % ( - self.__class__.__name__, value, self._semlock.maxvalue) + return 
'<%s(value=%s, maxvalue=%s)>' % \ + (self.__class__.__name__, value, self._semlock.maxvalue) +# +# Non-recursive lock +# class Lock(SemLock): - ''' - Non-recursive lock. - ''' - def __init__(self, ctx=None): + def __init__(self, *, ctx): SemLock.__init__(self, SEMAPHORE, 1, 1, ctx=ctx) def __repr__(self): @@ -197,13 +177,13 @@ def __repr__(self): name = 'unknown' return '<%s(owner=%s)>' % (self.__class__.__name__, name) +# +# Recursive lock +# class RLock(SemLock): - ''' - Recursive lock - ''' - def __init__(self, ctx=None): + def __init__(self, *, ctx): SemLock.__init__(self, RECURSIVE_MUTEX, 1, 1, ctx=ctx) def __repr__(self): @@ -223,14 +203,13 @@ def __repr__(self): name, count = 'unknown', 'unknown' return '<%s(%s, %s)>' % (self.__class__.__name__, name, count) +# +# Condition variable +# class Condition(object): - ''' - Condition variable - ''' - def __init__(self, lock=None, ctx=None): - assert ctx + def __init__(self, lock=None, *, ctx): self._lock = lock or ctx.RLock() self._sleeping_count = ctx.Semaphore(0) self._woken_count = ctx.Semaphore(0) @@ -263,12 +242,11 @@ def __repr__(self): self._woken_count._semlock._get_value()) except Exception: num_waiters = 'unknown' - return '<%s(%s, %s)>' % ( - self.__class__.__name__, self._lock, num_waiters) + return '<%s(%s, %s)>' % (self.__class__.__name__, self._lock, num_waiters) def wait(self, timeout=None): assert self._lock._semlock._is_mine(), \ - 'must acquire() condition before using wait()' + 'must acquire() condition before using wait()' # indicate that this thread is going to sleep self._sleeping_count.release() @@ -289,35 +267,21 @@ def wait(self, timeout=None): for i in range(count): self._lock.acquire() - def notify(self): - assert self._lock._semlock._is_mine(), 'lock is not owned' - assert not self._wait_semaphore.acquire(False) - - # to take account of timeouts since last notify() we subtract - # woken_count from sleeping_count and rezero woken_count - while self._woken_count.acquire(False): - res = 
self._sleeping_count.acquire(False) - assert res - - if self._sleeping_count.acquire(False): # try grabbing a sleeper - self._wait_semaphore.release() # wake up one sleeper - self._woken_count.acquire() # wait for sleeper to wake - - # rezero _wait_semaphore in case a timeout just happened - self._wait_semaphore.acquire(False) - - def notify_all(self): + def notify(self, n=1): assert self._lock._semlock._is_mine(), 'lock is not owned' - assert not self._wait_semaphore.acquire(False) + assert not self._wait_semaphore.acquire( + False), ('notify: Should not have been able to acquire' + + '_wait_semaphore') # to take account of timeouts since last notify*() we subtract # woken_count from sleeping_count and rezero woken_count while self._woken_count.acquire(False): res = self._sleeping_count.acquire(False) - assert res + assert res, ('notify: Bug in sleeping_count.acquire' + + '- res should not be False') sleepers = 0 - while self._sleeping_count.acquire(False): + while sleepers < n and self._sleeping_count.acquire(False): self._wait_semaphore.release() # wake up one sleeper sleepers += 1 @@ -329,29 +293,34 @@ def notify_all(self): while self._wait_semaphore.acquire(False): pass + def notify_all(self): + self.notify(n=sys.maxsize) + def wait_for(self, predicate, timeout=None): result = predicate() if result: return result if timeout is not None: - endtime = monotonic() + timeout + endtime = time.monotonic() + timeout else: endtime = None waittime = None while not result: if endtime is not None: - waittime = endtime - monotonic() + waittime = endtime - time.monotonic() if waittime <= 0: break self.wait(waittime) result = predicate() return result +# +# Event +# class Event(object): - def __init__(self, ctx=None): - assert ctx + def __init__(self, *, ctx): self._cond = ctx.Condition(ctx.Lock()) self._flag = ctx.Semaphore(0) @@ -388,50 +357,38 @@ def wait(self, timeout=None): # Barrier # +class Barrier(threading.Barrier): -if hasattr(threading, 'Barrier'): - - class 
Barrier(threading.Barrier): - - def __init__(self, parties, action=None, timeout=None, ctx=None): - assert ctx - import struct - from .heap import BufferWrapper - wrapper = BufferWrapper(struct.calcsize('i') * 2) - cond = ctx.Condition() - self.__setstate__((parties, action, timeout, cond, wrapper)) - self._state = 0 - self._count = 0 + def __init__(self, parties, action=None, timeout=None, *, ctx): + import struct + from .heap import BufferWrapper + wrapper = BufferWrapper(struct.calcsize('i') * 2) + cond = ctx.Condition() + self.__setstate__((parties, action, timeout, cond, wrapper)) + self._state = 0 + self._count = 0 - def __setstate__(self, state): - (self._parties, self._action, self._timeout, - self._cond, self._wrapper) = state - self._array = self._wrapper.create_memoryview().cast('i') - - def __getstate__(self): - return (self._parties, self._action, self._timeout, - self._cond, self._wrapper) - - @property - def _state(self): - return self._array[0] - - @_state.setter - def _state(self, value): # noqa - self._array[0] = value - - @property - def _count(self): - return self._array[1] + def __setstate__(self, state): + (self._parties, self._action, self._timeout, + self._cond, self._wrapper) = state + self._array = self._wrapper.create_memoryview().cast('i') - @_count.setter - def _count(self, value): # noqa - self._array[1] = value + def __getstate__(self): + return (self._parties, self._action, self._timeout, + self._cond, self._wrapper) + @property + def _state(self): + return self._array[0] -else: + @_state.setter + def _state(self, value): + self._array[0] = value - class Barrier(object): # noqa + @property + def _count(self): + return self._array[1] - def __init__(self, *args, **kwargs): - raise NotImplementedError('Barrier only supported on Py3') + @_count.setter + def _count(self, value): + self._array[1] = value diff --git a/billiard/util.py b/billiard/util.py index 1c74ccae..21f2a7eb 100644 --- a/billiard/util.py +++ b/billiard/util.py @@ -1,79 
+1,29 @@ # # Module providing various facilities to other parts of the package # -# billiard/util.py +# multiprocessing/util.py # -# Copyright (c) 2006-2008, R Oudkerk --- see COPYING.txt +# Copyright (c) 2006-2008, R Oudkerk # Licensed to PSF under a Contributor Agreement. # -from __future__ import absolute_import +import os +import itertools import sys -import errno -import functools +import weakref import atexit +import threading # we want threading to install it's + # cleanup function before multiprocessing does +from subprocess import _args_from_interpreter_flags -try: - import cffi -except ImportError: - import ctypes - -try: - from subprocess import _args_from_interpreter_flags # noqa -except ImportError: # pragma: no cover - def _args_from_interpreter_flags(): # noqa - """Return a list of command-line arguments reproducing the current - settings in sys.flags and sys.warnoptions.""" - flag_opt_map = { - 'debug': 'd', - 'optimize': 'O', - 'dont_write_bytecode': 'B', - 'no_user_site': 's', - 'no_site': 'S', - 'ignore_environment': 'E', - 'verbose': 'v', - 'bytes_warning': 'b', - 'hash_randomization': 'R', - 'py3k_warning': '3', - } - args = [] - for flag, opt in flag_opt_map.items(): - v = getattr(sys.flags, flag) - if v > 0: - args.append('-' + opt * v) - for opt in sys.warnoptions: - args.append('-W' + opt) - return args - -from multiprocessing.util import ( # noqa - _afterfork_registry, - _afterfork_counter, - _exit_function, - _finalizer_registry, - _finalizer_counter, - Finalize, - ForkAwareLocal, - ForkAwareThreadLock, - get_temp_dir, - is_exiting, - register_after_fork, - _run_after_forkers, - _run_finalizers, -) - -from .compat import get_errno +from . 
import process __all__ = [ 'sub_debug', 'debug', 'info', 'sub_warning', 'get_logger', 'log_to_stderr', 'get_temp_dir', 'register_after_fork', 'is_exiting', 'Finalize', 'ForkAwareThreadLock', 'ForkAwareLocal', - 'SUBDEBUG', 'SUBWARNING', -] - - -# Constants from prctl.h -PR_GET_PDEATHSIG = 2 -PR_SET_PDEATHSIG = 1 + 'close_all_fds_except', 'SUBDEBUG', 'SUBWARNING', + ] # # Logging @@ -84,8 +34,6 @@ def _args_from_interpreter_flags(): # noqa DEBUG = 10 INFO = 20 SUBWARNING = 25 -WARNING = 30 -ERROR = 40 LOGGER_NAME = 'multiprocessing' DEFAULT_LOGGING_FORMAT = '[%(levelname)s/%(processName)s] %(message)s' @@ -93,34 +41,21 @@ def _args_from_interpreter_flags(): # noqa _logger = None _log_to_stderr = False - -def sub_debug(msg, *args, **kwargs): - if _logger: - _logger.log(SUBDEBUG, msg, *args, **kwargs) - - -def debug(msg, *args, **kwargs): - if _logger: - _logger.log(DEBUG, msg, *args, **kwargs) - - -def info(msg, *args, **kwargs): +def sub_debug(msg, *args): if _logger: - _logger.log(INFO, msg, *args, **kwargs) + _logger.log(SUBDEBUG, msg, *args) - -def sub_warning(msg, *args, **kwargs): +def debug(msg, *args): if _logger: - _logger.log(SUBWARNING, msg, *args, **kwargs) + _logger.log(DEBUG, msg, *args) -def warning(msg, *args, **kwargs): +def info(msg, *args): if _logger: - _logger.log(WARNING, msg, *args, **kwargs) + _logger.log(INFO, msg, *args) -def error(msg, *args, **kwargs): +def sub_warning(msg, *args): if _logger: - _logger.log(ERROR, msg, *args, **kwargs) - + _logger.log(SUBWARNING, msg, *args) def get_logger(): ''' @@ -135,8 +70,6 @@ def get_logger(): _logger = logging.getLogger(LOGGER_NAME) _logger.propagate = 0 - logging.addLevelName(SUBDEBUG, 'SUBDEBUG') - logging.addLevelName(SUBWARNING, 'SUBWARNING') # XXX multiprocessing should cleanup before logging if hasattr(atexit, 'unregister'): @@ -145,12 +78,12 @@ def get_logger(): else: atexit._exithandlers.remove((_exit_function, (), {})) atexit._exithandlers.append((_exit_function, (), {})) + finally: 
logging._releaseLock() return _logger - def log_to_stderr(level=None): ''' Turn on logging and add a handler which prints to stderr @@ -170,64 +103,387 @@ def log_to_stderr(level=None): return _logger -def get_pdeathsig(): - """ - Return the current value of the parent process death signal - """ - if not sys.platform.startswith('linux'): - # currently we support only linux platform. - raise OSError() - try: - if 'cffi' in sys.modules: - ffi = cffi.FFI() - ffi.cdef("int prctl (int __option, ...);") - arg = ffi.new("int *") - C = ffi.dlopen(None) - C.prctl(PR_GET_PDEATHSIG, arg) - return arg[0] +# Abstract socket support + +def _platform_supports_abstract_sockets(): + if sys.platform == "linux": + return True + if hasattr(sys, 'getandroidapilevel'): + return True + return False + + +def is_abstract_socket_namespace(address): + if not address: + return False + if isinstance(address, bytes): + return address[0] == 0 + elif isinstance(address, str): + return address[0] == "\0" + raise TypeError('address type of {address!r} unrecognized') + + +abstract_sockets_supported = _platform_supports_abstract_sockets() + +# +# Function returning a temp directory which will be removed on exit +# + +def _remove_temp_dir(rmtree, tempdir): + rmtree(tempdir) + + current_process = process.current_process() + # current_process() can be None if the finalizer is called + # late during Python finalization + if current_process is not None: + current_process._config['tempdir'] = None + +def get_temp_dir(): + # get name of a temp directory which will be automatically cleaned up + tempdir = process.current_process()._config.get('tempdir') + if tempdir is None: + import shutil, tempfile + tempdir = tempfile.mkdtemp(prefix='pymp-') + info('created temp directory %s', tempdir) + # keep a strong reference to shutil.rmtree(), since the finalizer + # can be called late during Python shutdown + Finalize(None, _remove_temp_dir, args=(shutil.rmtree, tempdir), + exitpriority=-100) + 
process.current_process()._config['tempdir'] = tempdir + return tempdir + +# +# Support for reinitialization of objects when bootstrapping a child process +# + +_afterfork_registry = weakref.WeakValueDictionary() +_afterfork_counter = itertools.count() + +def _run_after_forkers(): + items = list(_afterfork_registry.items()) + items.sort() + for (index, ident, func), obj in items: + try: + func(obj) + except Exception as e: + info('after forker raised exception %s', e) + +def register_after_fork(obj, func): + _afterfork_registry[(next(_afterfork_counter), id(obj), func)] = obj + +# +# Finalization using weakrefs +# + +_finalizer_registry = {} +_finalizer_counter = itertools.count() + + +class Finalize(object): + ''' + Class which supports object finalization using weakrefs + ''' + def __init__(self, obj, callback, args=(), kwargs=None, exitpriority=None): + if (exitpriority is not None) and not isinstance(exitpriority,int): + raise TypeError( + "Exitpriority ({0!r}) must be None or int, not {1!s}".format( + exitpriority, type(exitpriority))) + + if obj is not None: + self._weakref = weakref.ref(obj, self) + elif exitpriority is None: + raise ValueError("Without object, exitpriority cannot be None") + + self._callback = callback + self._args = args + self._kwargs = kwargs or {} + self._key = (exitpriority, next(_finalizer_counter)) + self._pid = os.getpid() + + _finalizer_registry[self._key] = self + + def __call__(self, wr=None, + # Need to bind these locally because the globals can have + # been cleared at shutdown + _finalizer_registry=_finalizer_registry, + sub_debug=sub_debug, getpid=os.getpid): + ''' + Run the callback unless it has already been called or cancelled + ''' + try: + del _finalizer_registry[self._key] + except KeyError: + sub_debug('finalizer no longer registered') else: - sig = ctypes.c_int() - libc = ctypes.cdll.LoadLibrary("libc.so.6") - libc.prctl(PR_GET_PDEATHSIG, ctypes.byref(sig)) - return sig.value - except Exception: - raise OSError() - - 
-def set_pdeathsig(sig): - """ - Set the parent process death signal of the calling process to sig - (either a signal value in the range 1..maxsig, or 0 to clear). - This is the signal that the calling process will get when its parent dies. - This value is cleared for the child of a fork(2) and - (since Linux 2.4.36 / 2.6.23) when executing a set-user-ID or set-group-ID binary. - """ - if not sys.platform.startswith('linux'): - # currently we support only linux platform. - raise OSError() - try: - if 'cffi' in sys.modules: - ffi = cffi.FFI() - ffi.cdef("int prctl (int __option, ...);") - C = ffi.dlopen(None) - C.prctl(PR_SET_PDEATHSIG, ffi.cast("int", sig)) + if self._pid != getpid(): + sub_debug('finalizer ignored because different process') + res = None + else: + sub_debug('finalizer calling %s with args %s and kwargs %s', + self._callback, self._args, self._kwargs) + res = self._callback(*self._args, **self._kwargs) + self._weakref = self._callback = self._args = \ + self._kwargs = self._key = None + return res + + def cancel(self): + ''' + Cancel finalization of the object + ''' + try: + del _finalizer_registry[self._key] + except KeyError: + pass else: - libc = ctypes.cdll.LoadLibrary("libc.so.6") - libc.prctl(PR_SET_PDEATHSIG, sig) - except Exception: - raise OSError() + self._weakref = self._callback = self._args = \ + self._kwargs = self._key = None + + def still_active(self): + ''' + Return whether this finalizer is still waiting to invoke callback + ''' + return self._key in _finalizer_registry + + def __repr__(self): + try: + obj = self._weakref() + except (AttributeError, TypeError): + obj = None + + if obj is None: + return '<%s object, dead>' % self.__class__.__name__ + + x = '<%s object, callback=%s' % ( + self.__class__.__name__, + getattr(self._callback, '__name__', self._callback)) + if self._args: + x += ', args=' + str(self._args) + if self._kwargs: + x += ', kwargs=' + str(self._kwargs) + if self._key[0] is not None: + x += ', exitpriority=' + 
str(self._key[0]) + return x + '>' + + +def _run_finalizers(minpriority=None): + ''' + Run all finalizers whose exit priority is not None and at least minpriority -def _eintr_retry(func): + Finalizers with highest priority are called first; finalizers with + the same priority will be called in reverse order of creation. ''' - Automatic retry after EINTR. + if _finalizer_registry is None: + # This function may be called after this module's globals are + # destroyed. See the _exit_function function in this module for more + # notes. + return + + if minpriority is None: + f = lambda p : p[0] is not None + else: + f = lambda p : p[0] is not None and p[0] >= minpriority + + # Careful: _finalizer_registry may be mutated while this function + # is running (either by a GC run or by another thread). + + # list(_finalizer_registry) should be atomic, while + # list(_finalizer_registry.items()) is not. + keys = [key for key in list(_finalizer_registry) if f(key)] + keys.sort(reverse=True) + + for key in keys: + finalizer = _finalizer_registry.get(key) + # key may have been removed from the registry + if finalizer is not None: + sub_debug('calling %s', finalizer) + try: + finalizer() + except Exception: + import traceback + traceback.print_exc() + + if minpriority is None: + _finalizer_registry.clear() + +# +# Clean up on exit +# + +def is_exiting(): + ''' + Returns true if the process is shutting down ''' + return _exiting or _exiting is None + +_exiting = False + +def _exit_function(info=info, debug=debug, _run_finalizers=_run_finalizers, + active_children=process.active_children, + current_process=process.current_process): + # We hold on to references to functions in the arglist due to the + # situation described below, where this function is called after this + # module's globals are destroyed. 
+ + global _exiting + + if not _exiting: + _exiting = True + + info('process shutting down') + debug('running all "atexit" finalizers with priority >= 0') + _run_finalizers(0) + + if current_process() is not None: + # We check if the current process is None here because if + # it's None, any call to ``active_children()`` will raise + # an AttributeError (active_children winds up trying to + # get attributes from util._current_process). One + # situation where this can happen is if someone has + # manipulated sys.modules, causing this module to be + # garbage collected. The destructor for the module type + # then replaces all values in the module dict with None. + # For instance, after setuptools runs a test it replaces + # sys.modules with a copy created earlier. See issues + # #9775 and #15881. Also related: #4106, #9205, and + # #9207. + + for p in active_children(): + if p.daemon: + info('calling terminate() for daemon %s', p.name) + p._popen.terminate() + + for p in active_children(): + info('calling join() for process %s', p.name) + p.join() + + debug('running the remaining "atexit" finalizers') + _run_finalizers() + +atexit.register(_exit_function) - @functools.wraps(func) - def wrapped(*args, **kwargs): - while 1: - try: - return func(*args, **kwargs) - except OSError as exc: - if get_errno(exc) != errno.EINTR: - raise - return wrapped +# +# Some fork aware types +# + +class ForkAwareThreadLock(object): + def __init__(self): + self._lock = threading.Lock() + self.acquire = self._lock.acquire + self.release = self._lock.release + register_after_fork(self, ForkAwareThreadLock._at_fork_reinit) + + def _at_fork_reinit(self): + self._lock._at_fork_reinit() + + def __enter__(self): + return self._lock.__enter__() + + def __exit__(self, *args): + return self._lock.__exit__(*args) + + +class ForkAwareLocal(threading.local): + def __init__(self): + register_after_fork(self, lambda obj : obj.__dict__.clear()) + def __reduce__(self): + return type(self), () + +# +# 
Close fds except those specified +# + +try: + MAXFD = os.sysconf("SC_OPEN_MAX") +except Exception: + MAXFD = 256 + +def close_all_fds_except(fds): + fds = list(fds) + [-1, MAXFD] + fds.sort() + assert fds[-1] == MAXFD, 'fd too large' + for i in range(len(fds) - 1): + os.closerange(fds[i]+1, fds[i+1]) +# +# Close sys.stdin and replace stdin with os.devnull +# + +def _close_stdin(): + if sys.stdin is None: + return + + try: + sys.stdin.close() + except (OSError, ValueError): + pass + + try: + fd = os.open(os.devnull, os.O_RDONLY) + try: + sys.stdin = open(fd, closefd=False) + except: + os.close(fd) + raise + except (OSError, ValueError): + pass + +# +# Flush standard streams, if any +# + +def _flush_std_streams(): + try: + sys.stdout.flush() + except (AttributeError, ValueError): + pass + try: + sys.stderr.flush() + except (AttributeError, ValueError): + pass + +# +# Start a program with only specified fds kept open +# + +def spawnv_passfds(path, args, passfds): + import _posixsubprocess + passfds = tuple(sorted(map(int, passfds))) + errpipe_read, errpipe_write = os.pipe() + try: + return _posixsubprocess.fork_exec( + args, [os.fsencode(path)], True, passfds, None, None, + -1, -1, -1, -1, -1, -1, errpipe_read, errpipe_write, + False, False, None, None, None, -1, None) + finally: + os.close(errpipe_read) + os.close(errpipe_write) + + +def close_fds(*fds): + """Close each file descriptor given as an argument""" + for fd in fds: + os.close(fd) + + +def _cleanup_tests(): + """Cleanup multiprocessing resources when multiprocessing tests + completed.""" + + from test import support + + # cleanup multiprocessing + process._cleanup() + + # Stop the ForkServer process if it's running + from multiprocessing import forkserver + forkserver._forkserver._stop() + + # Stop the ResourceTracker process if it's running + from multiprocessing import resource_tracker + resource_tracker._resource_tracker._stop() + + # bpo-37421: Explicitly call _run_finalizers() to remove immediately + # 
temporary directories created by multiprocessing.util.get_temp_dir(). + _run_finalizers() + support.gc_collect() + + support.reap_children()