diff options
Diffstat (limited to '')
-rw-r--r-- | test/unit/applications/websockets.py | 31 | ||||
-rw-r--r-- | test/unit/http.py | 20 | ||||
-rw-r--r-- | test/unit/main.py | 48 |
3 files changed, 76 insertions, 23 deletions
diff --git a/test/unit/applications/websockets.py b/test/unit/applications/websockets.py index 50ff2797..ef16f433 100644 --- a/test/unit/applications/websockets.py +++ b/test/unit/applications/websockets.py @@ -1,3 +1,4 @@ +import re import random import base64 import struct @@ -30,25 +31,37 @@ class TestApplicationWebsocket(TestApplicationProto): sha1 = hashlib.sha1((key + GUID).encode()).digest() return base64.b64encode(sha1).decode() - def upgrade(self): - key = self.key() + def upgrade(self, headers=None): + key = None - if self.preinit: - self.get() - - resp, sock = self.get( - headers={ + if headers is None: + key = self.key() + headers = { 'Host': 'localhost', 'Upgrade': 'websocket', 'Connection': 'Upgrade', 'Sec-WebSocket-Key': key, 'Sec-WebSocket-Protocol': 'chat', 'Sec-WebSocket-Version': 13, - }, - read_timeout=1, + } + + _, sock = self.get( + headers=headers, + no_recv=True, start=True, ) + resp = '' + while select.select([sock], [], [], 30)[0]: + resp += sock.recv(4096).decode() + + if ( + re.search('101 Switching Protocols', resp) + and resp[-4:] == '\r\n\r\n' + ): + resp = self._resp_to_dict(resp) + break + return (resp, sock, key) def apply_mask(self, data, mask): diff --git a/test/unit/http.py b/test/unit/http.py index 82a6bd6a..c7e3e36d 100644 --- a/test/unit/http.py +++ b/test/unit/http.py @@ -1,4 +1,5 @@ import re +import time import socket import select from unit.main import TestUnit @@ -63,7 +64,7 @@ class TestHTTP(TestUnit): if 'raw' not in kwargs: req = ' '.join([start_str, url, http]) + crlf - if body is not b'': + if body != b'': if isinstance(body, str): body = body.encode() @@ -178,3 +179,20 @@ class TestHTTP(TestUnit): headers[m.group(1)] = [headers[m.group(1)], m.group(2)] return {'status': int(status), 'headers': headers, 'body': body} + + def waitforsocket(self, port): + ret = False + + for i in range(50): + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(('127.0.0.1', port)) + ret = True + break + except: + sock.close() + time.sleep(0.1) + + sock.close() + + self.assertTrue(ret, 'socket connected') diff --git a/test/unit/main.py b/test/unit/main.py index 873f1815..094fdb0e 100644 --- a/test/unit/main.py +++ b/test/unit/main.py @@ -185,6 +185,8 @@ class TestUnit(unittest.TestCase): if self._started: self._stop() + self.stop_processes() + def _run(self): self.unitd = self.pardir + '/build/unitd' @@ -240,24 +242,24 @@ class TestUnit(unittest.TestCase): break time.sleep(0.1) - if os.path.exists(self.testdir + '/unit.pid'): - exit("Could not terminate unit") + self._p.join(timeout=5) - self._started = False + if self._p.is_alive(): + self._p.terminate() + self._p.join(timeout=5) - self._p.join(timeout=1) - self._terminate_process(self._p) + if self._p.is_alive(): + self.fail("Could not terminate process " + str(self._p.pid)) - def _terminate_process(self, process): - if process.is_alive(): - process.terminate() - process.join(timeout=5) + if os.path.exists(self.testdir + '/unit.pid'): + self.fail("Could not terminate unit") - if process.is_alive(): - exit("Could not terminate process " + process.pid) + self._started = False - if process.exitcode: - exit("Child process terminated with code " + str(process.exitcode)) + if self._p.exitcode: + self.fail( + "Child process terminated with code " + str(self._p.exitcode) + ) def _check_alerts(self, log): found = False @@ -287,6 +289,26 @@ class TestUnit(unittest.TestCase): if found: print('skipped.') + def run_process(self, target): + if not hasattr(self, '_processes'): + self._processes = [] + + process = Process(target=target) + process.start() + + self._processes.append(process) + + def stop_processes(self): + if not hasattr(self, '_processes'): + return + + for process in self._processes: + process.terminate() + process.join(timeout=5) + + if process.is_alive(): + self.fail('Fail to stop process') + def waitforfiles(self, *files): for i in range(50): wait = False |