|
1 | 1 | from __future__ import print_function
|
2 | 2 |
|
| 3 | +from itertools import chain |
| 4 | +import multiprocessing |
| 5 | +import os |
| 6 | +import signal |
| 7 | +import sys |
| 8 | +import traceback |
3 | 9 | import unittest
|
4 | 10 |
|
5 | 11 | import six
|
@@ -166,6 +172,73 @@ def test_disconnect_all_delete_multi(self):
|
166 | 172 | ret = self.mc.delete_multi({'keyhere': 'a', 'keythere': 'b'})
|
167 | 173 | self.assertEqual(ret, 1)
|
168 | 174 |
|
| 175 | + def test_exception_handling(self): |
| 176 | + """Tests closing socket when custom exception raised""" |
| 177 | + queue = multiprocessing.Queue() |
| 178 | + process = multiprocessing.Process(target=worker, args=(self.mc, queue)) |
| 179 | + process.start() |
| 180 | + if queue.get() != 'loop started': |
| 181 | + raise ValueError( |
| 182 | + 'Expected "loop started" message from the child process' |
| 183 | + ) |
| 184 | + |
| 185 | + # maximum test duration is 0.5 second |
| 186 | + num_iters = 50 |
| 187 | + timeout = 0.01 |
| 188 | + for i in range(num_iters): |
| 189 | + os.kill(process.pid, signal.SIGUSR1) |
| 190 | + try: |
| 191 | + exc = WorkerError(*queue.get(timeout=timeout)) |
| 192 | + raise exc |
| 193 | + except six.moves.queue.Empty: |
| 194 | + pass |
| 195 | + if not process.is_alive(): |
| 196 | + break |
| 197 | + |
| 198 | + if process.is_alive(): |
| 199 | + os.kill(process.pid, signal.SIGTERM) |
| 200 | + process.join() |
| 201 | + |
| 202 | + |
| 203 | +class SignalException(Exception): |
| 204 | + pass |
| 205 | + |
| 206 | + |
| 207 | +def sighandler(signum, frame): |
| 208 | + raise SignalException() |
| 209 | + |
| 210 | + |
| 211 | +class WorkerError(Exception): |
| 212 | + def __init__(self, exc, assert_tb, signal_tb=None): |
| 213 | + super(WorkerError, self).__init__( |
| 214 | + ''.join(chain(assert_tb, signal_tb or [])) |
| 215 | + ) |
| 216 | + self.cause = exc |
| 217 | + |
| 218 | + |
| 219 | +def worker(mc, queue): |
| 220 | + signal.signal(signal.SIGUSR1, sighandler) |
| 221 | + |
| 222 | + signal_tb = None |
| 223 | + for i in range(100000): |
| 224 | + if i == 0: |
| 225 | + queue.put('loop started') |
| 226 | + try: |
| 227 | + k = str(i) |
| 228 | + mc.set(k, i) |
| 229 | + # This loop is just to increase chance to get previous value |
| 230 | + # for clarity |
| 231 | + for j in range(10): |
| 232 | + mc.get(str(i-1)) |
| 233 | + res = mc.get(k) |
| 234 | + assert res == i, 'Expected {} but was {}'.format(i, res) |
| 235 | + except AssertionError as e: |
| 236 | + assert_tb = traceback.format_exception(*sys.exc_info()) |
| 237 | + queue.put((e, assert_tb, signal_tb)) |
| 238 | + break |
| 239 | + except SignalException as e: |
| 240 | + signal_tb = traceback.format_exception(*sys.exc_info()) |
| 241 | + |
169 | 242 |
|
170 | 243 | if __name__ == '__main__':
|
171 | 244 | unittest.main()
|
0 commit comments