|
1 | 1 | from __future__ import print_function
|
2 | 2 |
|
3 |
| -from itertools import chain |
4 |
| -import multiprocessing |
5 |
| -import os |
6 |
| -import signal |
7 | 3 | import socket
|
8 |
| -import sys |
9 |
| -import traceback |
10 | 4 | import unittest
|
11 | 5 |
|
12 | 6 | import six
|
@@ -213,70 +207,19 @@ def test_socket_error(self):
|
213 | 207 |
|
214 | 208 | def test_exception_handling(self):
|
215 | 209 | """Tests closing socket when custom exception raised"""
|
216 |
| - queue = multiprocessing.Queue() |
217 |
| - process = multiprocessing.Process(target=worker, args=(self.mc, queue)) |
218 |
| - process.start() |
219 |
| - if queue.get() != 'loop started': |
220 |
| - raise ValueError( |
221 |
| - 'Expected "loop started" message from the child process' |
222 |
| - ) |
| 210 | + class CustomException(Exception): |
| 211 | + pass |
223 | 212 |
|
224 |
| - # maximum test duration is 0.5 second |
225 |
| - num_iters = 50 |
226 |
| - timeout = 0.01 |
227 |
| - for i in range(num_iters): |
228 |
| - os.kill(process.pid, signal.SIGUSR1) |
| 213 | + self.mc.set('error', 1) |
| 214 | + with patch.object(self.mc, '_recv_value', |
| 215 | + Mock(side_effect=CustomException('custom error'))): |
229 | 216 | try:
|
230 |
| - exc = WorkerError(*queue.get(timeout=timeout)) |
231 |
| - raise exc |
232 |
| - except six.moves.queue.Empty: |
| 217 | + self.mc.get('error') |
| 218 | + except CustomException: |
233 | 219 | pass
|
234 |
| - if not process.is_alive(): |
235 |
| - break |
236 |
| - |
237 |
| - if process.is_alive(): |
238 |
| - os.kill(process.pid, signal.SIGTERM) |
239 |
| - process.join() |
240 |
| - |
241 |
| - |
242 |
| -class SignalException(Exception): |
243 |
| - pass |
244 |
| - |
245 |
| - |
246 |
| -def sighandler(signum, frame): |
247 |
| - raise SignalException() |
248 |
| - |
249 |
| - |
250 |
| -class WorkerError(Exception): |
251 |
| - def __init__(self, exc, assert_tb, signal_tb=None): |
252 |
| - super(WorkerError, self).__init__( |
253 |
| - ''.join(chain(assert_tb, signal_tb or [])) |
254 |
| - ) |
255 |
| - self.cause = exc |
256 |
| - |
257 |
| - |
258 |
| -def worker(mc, queue): |
259 |
| - signal.signal(signal.SIGUSR1, sighandler) |
260 |
| - |
261 |
| - signal_tb = None |
262 |
| - for i in range(100000): |
263 |
| - if i == 0: |
264 |
| - queue.put('loop started') |
265 |
| - try: |
266 |
| - k = str(i) |
267 |
| - mc.set(k, i) |
268 |
| - # This loop is just to increase chance to get previous value |
269 |
| - # for clarity |
270 |
| - for j in range(10): |
271 |
| - mc.get(str(i-1)) |
272 |
| - res = mc.get(k) |
273 |
| - assert res == i, 'Expected {} but was {}'.format(i, res) |
274 |
| - except AssertionError as e: |
275 |
| - assert_tb = traceback.format_exception(*sys.exc_info()) |
276 |
| - queue.put((e, assert_tb, signal_tb)) |
277 |
| - break |
278 |
| - except SignalException as e: |
279 |
| - signal_tb = traceback.format_exception(*sys.exc_info()) |
| 220 | + self.assertIs(self.mc.servers[0].socket, None) |
| 221 | + self.assertEqual(self.mc.set('error', 2), True) |
| 222 | + self.assertEqual(self.mc.get('error'), 2) |
280 | 223 |
|
281 | 224 |
|
282 | 225 | if __name__ == '__main__':
|
|
0 commit comments