|
1 | 1 | from __future__ import print_function |
2 | 2 |
|
3 | | -from itertools import chain |
4 | | -import multiprocessing |
5 | | -import os |
6 | | -import signal |
7 | 3 | import socket |
8 | | -import sys |
9 | | -import traceback |
10 | 4 | import unittest |
11 | 5 |
|
12 | 6 | import six |
@@ -212,70 +206,19 @@ def test_socket_error(self): |
212 | 206 |
|
213 | 207 | def test_exception_handling(self): |
214 | 208 | """Tests closing socket when custom exception raised""" |
215 | | - queue = multiprocessing.Queue() |
216 | | - process = multiprocessing.Process(target=worker, args=(self.mc, queue)) |
217 | | - process.start() |
218 | | - if queue.get() != 'loop started': |
219 | | - raise ValueError( |
220 | | - 'Expected "loop started" message from the child process' |
221 | | - ) |
| 209 | + class CustomException(Exception): |
| 210 | + pass |
222 | 211 |
|
223 | | - # maximum test duration is 0.5 second |
224 | | - num_iters = 50 |
225 | | - timeout = 0.01 |
226 | | - for i in range(num_iters): |
227 | | - os.kill(process.pid, signal.SIGUSR1) |
| 212 | + self.mc.set('error', 1) |
| 213 | + with patch.object(self.mc, '_recv_value', |
| 214 | + Mock(side_effect=CustomException('custom error'))): |
228 | 215 | try: |
229 | | - exc = WorkerError(*queue.get(timeout=timeout)) |
230 | | - raise exc |
231 | | - except six.moves.queue.Empty: |
| 216 | + self.mc.get('error') |
| 217 | + except CustomException: |
232 | 218 | pass |
233 | | - if not process.is_alive(): |
234 | | - break |
235 | | - |
236 | | - if process.is_alive(): |
237 | | - os.kill(process.pid, signal.SIGTERM) |
238 | | - process.join() |
239 | | - |
240 | | - |
241 | | -class SignalException(Exception): |
242 | | - pass |
243 | | - |
244 | | - |
245 | | -def sighandler(signum, frame): |
246 | | - raise SignalException() |
247 | | - |
248 | | - |
249 | | -class WorkerError(Exception): |
250 | | - def __init__(self, exc, assert_tb, signal_tb=None): |
251 | | - super(WorkerError, self).__init__( |
252 | | - ''.join(chain(assert_tb, signal_tb or [])) |
253 | | - ) |
254 | | - self.cause = exc |
255 | | - |
256 | | - |
257 | | -def worker(mc, queue): |
258 | | - signal.signal(signal.SIGUSR1, sighandler) |
259 | | - |
260 | | - signal_tb = None |
261 | | - for i in range(100000): |
262 | | - if i == 0: |
263 | | - queue.put('loop started') |
264 | | - try: |
265 | | - k = str(i) |
266 | | - mc.set(k, i) |
267 | | - # This loop is just to increase chance to get previous value |
268 | | - # for clarity |
269 | | - for j in range(10): |
270 | | - mc.get(str(i-1)) |
271 | | - res = mc.get(k) |
272 | | - assert res == i, 'Expected {} but was {}'.format(i, res) |
273 | | - except AssertionError as e: |
274 | | - assert_tb = traceback.format_exception(*sys.exc_info()) |
275 | | - queue.put((e, assert_tb, signal_tb)) |
276 | | - break |
277 | | - except SignalException as e: |
278 | | - signal_tb = traceback.format_exception(*sys.exc_info()) |
| 219 | + self.assertIs(self.mc.servers[0].socket, None) |
| 220 | + self.assertEqual(self.mc.set('error', 2), True) |
| 221 | + self.assertEqual(self.mc.get('error'), 2) |
279 | 222 |
|
280 | 223 |
|
281 | 224 | if __name__ == '__main__': |
|
0 commit comments