diff options
-rw-r--r-- | test/test_tls.py | 417 | ||||
-rw-r--r-- | test/unit.py | 92 |
2 files changed, 505 insertions, 4 deletions
diff --git a/test/test_tls.py b/test/test_tls.py new file mode 100644 index 00000000..26bedcb7 --- /dev/null +++ b/test/test_tls.py @@ -0,0 +1,417 @@ +import re +import ssl +import time +import subprocess +import unittest +import unit + +class TestUnitTLS(unit.TestUnitApplicationTLS): + + def setUpClass(): + unit.TestUnit().check_modules('python', 'openssl') + + def findall(self, pattern): + with open(self.testdir + '/unit.log', 'r', errors='ignore') as f: + return re.findall(pattern, f.read()) + + def wait_for_record(self, pattern): + for i in range(50): + with open(self.testdir + '/unit.log', 'r', errors='ignore') as f: + if re.search(pattern, f.read()) is not None: + break + + time.sleep(0.1) + + def openssl_date_to_sec_epoch(self, date): + return self.date_to_sec_epoch(date, '%b %d %H:%M:%S %Y %Z') + + def add_tls(self, application='empty', cert='default', port=7080): + self.conf({ + "application": application, + "tls": { + "certificate": cert + } + }, 'listeners/*:' + str(port)) + + def remove_tls(self, application='empty', port=7080): + self.conf({ + "application": application + }, 'listeners/*:' + str(port)) + + def test_tls_listener_option_add(self): + self.load('empty') + + self.certificate() + + self.add_tls() + + self.assertEqual(self.get_ssl()['status'], 200, 'add listener option') + + def test_tls_listener_option_remove(self): + self.load('empty') + + self.certificate() + + self.add_tls() + + self.get_ssl() + + self.remove_tls() + + self.assertEqual(self.get()['status'], 200, 'remove listener option') + + def test_tls_certificate_remove(self): + self.load('empty') + + self.certificate() + + self.assertIn('success', self.conf_delete('/certificates/default'), + 'remove certificate') + + def test_tls_certificate_remove_used(self): + self.load('empty') + + self.certificate() + + self.add_tls() + + self.assertIn('error', self.conf_delete('/certificates/default'), + 'remove certificate') + + def test_tls_certificate_remove_nonexisting(self): + self.load('empty') + + self.certificate() + + self.add_tls() + + self.assertIn('error', self.conf_delete('/certificates/blah'), + 'remove nonexistings certificate') + + @unittest.expectedFailure + def test_tls_certificate_update(self): + self.load('empty') + + self.certificate() + + self.add_tls() + + cert_old = self.get_server_certificate() + + self.certificate() + + self.assertNotEqual(cert_old, self.get_server_certificate(), + 'update certificate') + + @unittest.expectedFailure + def test_tls_certificate_key_incorrect(self): + self.load('empty') + + self.certificate('first', False) + self.certificate('second', False) + + self.assertIn('error', self.certificate_load('first', 'second'), + 'key incorrect') + + def test_tls_certificate_change(self): + self.load('empty') + + self.certificate() + self.certificate('new') + + self.add_tls() + + cert_old = self.get_server_certificate() + + self.add_tls(cert='new') + + self.assertNotEqual(cert_old, self.get_server_certificate(), + 'change certificate') + + def test_tls_certificate_key_rsa(self): + self.load('empty') + + self.certificate() + + self.assertEqual(self.conf_get('/certificates/default/key'), + 'RSA (1024 bits)', 'certificate key rsa') + + def test_tls_certificate_key_ec(self): + subprocess.call(['openssl', 'ecparam', '-noout', '-genkey', + '-out', self.testdir + '/ec.key', + '-name', 'prime256v1']) + + subprocess.call(['openssl', 'req', '-x509', '-new', + '-key', self.testdir + '/ec.key', '-subj', '/CN=ec/', + '-out', self.testdir + '/ec.crt']) + + self.certificate_load('ec') + + self.assertEqual(self.conf_get('/certificates/ec/key'), 'ECDH', + 'certificate key ec') + + def test_tls_certificate_chain_options(self): + self.load('empty') + + self.certificate() + + chain = self.conf_get('/certificates/default/chain') + + self.assertEqual(len(chain), 1, 'certificate chain length') + + cert = chain[0] + + self.assertEqual(cert['subject']['common_name'], 'default', + 'certificate subject common name') + self.assertEqual(cert['issuer']['common_name'], 'default', + 'certificate issuer common name') + + self.assertLess(abs(self.sec_epoch() - + self.openssl_date_to_sec_epoch(cert['validity']['since'])), 5, + 'certificate validity since') + self.assertEqual( + self.openssl_date_to_sec_epoch(cert['validity']['until']) - + self.openssl_date_to_sec_epoch(cert['validity']['since']), 2592000, + 'certificate validity until') + + def test_tls_certificate_chain(self): + self.load('empty') + + self.certificate('root', False) + + subprocess.call(['openssl', 'req', '-new', '-config', + self.testdir + '/openssl.conf', '-subj', '/CN=int/', + '-out', self.testdir + '/int.csr', + '-keyout', self.testdir + '/int.key']) + + subprocess.call(['openssl', 'req', '-new', '-config', + self.testdir + '/openssl.conf', '-subj', '/CN=end/', + '-out', self.testdir + '/end.csr', + '-keyout', self.testdir + '/end.key']) + + with open(self.testdir + '/ca.conf', 'w') as f: + f.write("""[ ca ] +default_ca = myca + +[ myca ] +new_certs_dir = %(dir)s +database = %(database)s +default_md = sha1 +policy = myca_policy +serial = %(certserial)s +default_days = 1 +x509_extensions = myca_extensions + +[ myca_policy ] +commonName = supplied + +[ myca_extensions ] +basicConstraints = critical,CA:TRUE""" % { + 'dir': self.testdir, + 'database': self.testdir + '/certindex', + 'certserial': self.testdir + '/certserial' + }) + + with open(self.testdir + '/certserial', 'w') as f: + f.write('1000') + + with open(self.testdir + '/certindex', 'w') as f: + f.write('') + + subprocess.call(['openssl', 'ca', '-batch', + '-config', self.testdir + '/ca.conf', + '-keyfile', self.testdir + '/root.key', + '-cert', self.testdir + '/root.crt', + '-subj', '/CN=int/', + '-in', self.testdir + '/int.csr', + '-out', self.testdir + '/int.crt']) + + subprocess.call(['openssl', 'ca', '-batch', + '-config', self.testdir + '/ca.conf', + '-keyfile', self.testdir + '/int.key', + '-cert', self.testdir + '/int.crt', + '-subj', '/CN=end/', + '-in', self.testdir + '/end.csr', + '-out', self.testdir + '/end.crt']) + + with open(self.testdir + '/end-int.crt', 'wb') as crt, \ + open(self.testdir + '/end.crt', 'rb') as end, \ + open(self.testdir + '/int.crt', 'rb') as int: + crt.write(end.read() + int.read()) + + self.context = ssl.create_default_context() + self.context.check_hostname = False + self.context.verify_mode = ssl.CERT_REQUIRED + self.context.load_verify_locations(self.testdir + '/root.crt') + + # incomplete chain + + self.assertIn('success', self.certificate_load('end', 'end'), + 'certificate chain end upload') + + chain = self.conf_get('/certificates/end/chain') + self.assertEqual(len(chain), 1, 'certificate chain end length') + self.assertEqual(chain[0]['subject']['common_name'], 'end', + 'certificate chain end subject common name') + self.assertEqual(chain[0]['issuer']['common_name'], 'int', + 'certificate chain end issuer common name') + + self.add_tls(cert='end') + + try: + resp = self.get_ssl() + except ssl.SSLError: + resp = None + + self.assertEqual(resp, None, 'certificate chain incomplete chain') + + # intermediate + + self.assertIn('success', self.certificate_load('int', 'int'), + 'certificate chain int upload') + + chain = self.conf_get('/certificates/int/chain') + self.assertEqual(len(chain), 1, 'certificate chain int length') + self.assertEqual(chain[0]['subject']['common_name'], 'int', + 'certificate chain int subject common name') + self.assertEqual(chain[0]['issuer']['common_name'], 'root', + 'certificate chain int issuer common name') + + self.add_tls(cert='int') + + self.assertEqual(self.get_ssl()['status'], 200, + 'certificate chain intermediate') + + # intermediate server + + self.assertIn('success', self.certificate_load('end-int', 'end'), + 'certificate chain end-int upload') + + chain = self.conf_get('/certificates/end-int/chain') + self.assertEqual(len(chain), 2, 'certificate chain end-int length') + self.assertEqual(chain[0]['subject']['common_name'], 'end', + 'certificate chain end-int int subject common name') + self.assertEqual(chain[0]['issuer']['common_name'], 'int', + 'certificate chain end-int int issuer common name') + self.assertEqual(chain[1]['subject']['common_name'], 'int', + 'certificate chain end-int end subject common name') + self.assertEqual(chain[1]['issuer']['common_name'], 'root', + 'certificate chain end-int end issuer common name') + + self.add_tls(cert='end-int') + + self.assertEqual(self.get_ssl()['status'], 200, + 'certificate chain intermediate server') + + def test_tls_reconfigure(self): + self.load('empty') + + self.certificate() + + (resp, sock) = self.http(b"""GET / HTTP/1.1 +""", start=True, raw=True, no_recv=True) + + self.add_tls() + + resp = self.http(b"""Host: localhost +Connection: close + +""", sock=sock, raw=True) + + self.assertEqual(resp['status'], 200, 'update status') + self.assertEqual(self.get_ssl()['status'], 200, 'update tls status') + + def test_tls_keepalive(self): + self.load('mirror') + + self.certificate() + + self.add_tls(application='mirror') + + (resp, sock) = self.post_ssl(headers={ + 'Connection': 'keep-alive', + 'Content-Type': 'text/html', + 'Host': 'localhost' + }, start=True, body='0123456789') + + self.assertEqual(resp['body'], '0123456789', 'keepalive 1') + + resp = self.post_ssl(headers={ + 'Connection': 'close', + 'Content-Type': 'text/html', + 'Host': 'localhost' + }, sock=sock, body='0123456789') + + self.assertEqual(resp['body'], '0123456789', 'keepalive 2') + + @unittest.expectedFailure + def test_tls_keepalive_certificate_remove(self): + self.load('empty') + + self.certificate() + + self.add_tls() + + (resp, sock) = self.get_ssl(headers={ + 'Connection': 'keep-alive', + 'Host': 'localhost' + }, start=True) + + self.conf({ + "application": "empty" + }, 'listeners/*:7080') + self.conf_delete('/certificates/default') + + try: + resp = self.get_ssl(headers={ + 'Connection': 'close', + 'Host': 'localhost' + }, sock=sock) + except: + resp = None + + self.assertEqual(resp, None, 'keepalive remove certificate') + + @unittest.expectedFailure + def test_tls_certificates_remove_all(self): + self.load('empty') + + self.certificate() + + self.assertIn('success', self.conf_delete('/certificates'), + 'remove all certificates') + + def test_tls_application_respawn(self): + self.skip_alerts.append(r'process \d+ exited on signal 9') + self.load('mirror') + + self.certificate() + + self.conf('1', 'applications/mirror/processes') + + self.add_tls(application='mirror') + + (resp, sock) = self.post_ssl(headers={ + 'Connection': 'keep-alive', + 'Content-Type': 'text/html', + 'Host': 'localhost' + }, start=True, body='0123456789') + + app_id = self.findall(r'(\d+)#\d+ "mirror" application started')[0] + + subprocess.call(['kill', '-9', app_id]) + + self.wait_for_record(re.compile(' (?!' + app_id + + '#)(\d+)#\d+ "mirror" application started')) + + resp = self.post_ssl(headers={ + 'Connection': 'close', + 'Content-Type': 'text/html', + 'Host': 'localhost' + }, sock=sock, body='0123456789') + + self.assertEqual(resp['status'], 200, 'application respawn status') + self.assertEqual(resp['body'], '0123456789', 'application respawn body') + +if __name__ == '__main__': + unittest.main() diff --git a/test/unit.py b/test/unit.py index 58c8327d..cec62489 100644 --- a/test/unit.py +++ b/test/unit.py @@ -1,5 +1,6 @@ import os import re +import ssl import sys import json import time @@ -67,6 +68,19 @@ class TestUnit(unittest.TestCase): except: m = None + elif module == 'openssl': + try: + subprocess.check_output(['which', 'openssl']) + + output = subprocess.check_output([ + self.pardir + '/build/unitd', '--version'], + stderr=subprocess.STDOUT) + + m = re.search('--openssl', output.decode()) + + except: + m = None + else: m = re.search('module: ' + module, log) @@ -192,6 +206,7 @@ class TestUnitHTTP(TestUnit): port = 7080 if 'port' not in kwargs else kwargs['port'] url = '/' if 'url' not in kwargs else kwargs['url'] http = 'HTTP/1.0' if 'http_10' in kwargs else 'HTTP/1.1' + blocking = False if 'blocking' not in kwargs else kwargs['blocking'] headers = ({ 'Host': 'localhost', @@ -215,6 +230,9 @@ class TestUnitHTTP(TestUnit): if 'sock' not in kwargs: sock = socket.socket(sock_types[sock_type], socket.SOCK_STREAM) + if 'wrapper' in kwargs: + sock = kwargs['wrapper'](sock) + connect_args = addr if sock_type == 'unix' else (addr, port) try: sock.connect(connect_args) @@ -222,11 +240,11 @@ class TestUnitHTTP(TestUnit): sock.close() return None + sock.setblocking(blocking) + else: sock = kwargs['sock'] - sock.setblocking(False) - if 'raw' not in kwargs: req = ' '.join([start_str, url, http]) + crlf @@ -371,8 +389,8 @@ class TestUnitApplicationProto(TestUnitControl): def sec_epoch(self): return time.mktime(time.gmtime()) - def date_to_sec_epoch(self, date): - return time.mktime(time.strptime(date, '%a, %d %b %Y %H:%M:%S %Z')) + def date_to_sec_epoch(self, date, template='%a, %d %b %Y %H:%M:%S %Z'): + return time.mktime(time.strptime(date, template)) def search_in_log(self, pattern): with open(self.testdir + '/unit.log', 'r', errors='ignore') as f: @@ -484,3 +502,69 @@ class TestUnitApplicationPerl(TestUnitApplicationProto): } } }) + +class TestUnitApplicationTLS(TestUnitApplicationProto): + def __init__(self, test): + super().__init__(test) + + self.context = ssl.create_default_context() + self.context.check_hostname = False + self.context.verify_mode = ssl.CERT_NONE + + def certificate(self, name='default', load=True): + subprocess.call(['openssl', 'req', '-x509', '-new', '-config', + self.testdir + '/openssl.conf', '-subj', '/CN=' + name + '/', + '-out', self.testdir + '/' + name + '.crt', + '-keyout', self.testdir + '/' + name + '.key']) + + if load: + self.certificate_load(name) + + def certificate_load(self, crt, key=None): + if key is None: + key = crt + + with open(self.testdir + '/' + key + '.key', 'rb') as k, \ + open(self.testdir + '/' + crt + '.crt', 'rb') as c: + return self.conf(k.read() + c.read(), '/certificates/' + crt) + + def get_ssl(self, **kwargs): + return self.get(blocking=True, wrapper=self.context.wrap_socket, + **kwargs) + + def post_ssl(self, **kwargs): + return self.post(blocking=True, wrapper=self.context.wrap_socket, + **kwargs) + + def get_server_certificate(self, addr=('127.0.0.1', 7080)): + return ssl.get_server_certificate(addr) + + def load(self, script, name=None): + if name is None: + name = script + + # create default openssl configuration + + with open(self.testdir + '/openssl.conf', 'w') as f: + f.write("""[ req ] +default_bits = 1024 +encrypt_key = no +distinguished_name = req_distinguished_name +[ req_distinguished_name ]""") + + self.conf({ + "listeners": { + "*:7080": { + "application": name + } + }, + "applications": { + name: { + "type": "python", + "processes": { "spare": 0 }, + "path": self.current_dir + '/python/' + script, + "working_directory": self.current_dir + '/python/' + script, + "module": "wsgi" + } + } + }) |