import re
import ssl
import time
import subprocess
import unittest
import unit
class TestUnitTLS(unit.TestUnitApplicationTLS):
    """TLS tests: the listener ``tls`` option, the certificate store and
    certificate chains.

    Helpers such as ``load``, ``certificate``, ``certificate_load``, ``conf``,
    ``conf_get``, ``conf_delete``, ``get``, ``get_ssl``, ``post_ssl``,
    ``http``, ``get_server_certificate``, ``sec_epoch``, ``date_to_sec_epoch``
    and the ``testdir`` / ``skip_alerts`` attributes come from
    ``unit.TestUnitApplicationTLS``; they are not defined in this file.
    """

    @classmethod
    def setUpClass(cls):
        """Check once, for the whole class, that the required modules exist."""
        # NOTE(review): presumably check_modules() skips the class when the
        # python or openssl module is unavailable -- confirm in unit.py.
        unit.TestUnit().check_modules('python', 'openssl')

    def findall(self, pattern):
        """Return ``re.findall(pattern, ...)`` over the current ``unit.log``."""
        with open(self.testdir + '/unit.log', 'r', errors='ignore') as f:
            return re.findall(pattern, f.read())

    def wait_for_record(self, pattern, attempts=50, delay=0.1):
        """Poll ``unit.log`` until *pattern* (str or compiled regex) appears.

        The log is re-read up to *attempts* times, sleeping *delay* seconds
        after each miss (5 seconds in total with the defaults).

        Returns the first ``re.Match`` found, or ``None`` if the record never
        appeared; the caller decides whether a timeout is fatal.
        """
        for _ in range(attempts):
            with open(self.testdir + '/unit.log', 'r', errors='ignore') as f:
                found = re.search(pattern, f.read())

            if found is not None:
                return found

            time.sleep(delay)

        return None

    def openssl_date_to_sec_epoch(self, date):
        """Convert an OpenSSL-style date (``'%b %d %H:%M:%S %Y %Z'``) to
        seconds since the epoch via the inherited ``date_to_sec_epoch``."""
        return self.date_to_sec_epoch(date, '%b %d %H:%M:%S %Y %Z')

    def add_tls(self, application='empty', cert='default', port=7080):
        """Configure listener ``*:<port>`` to serve *application* over TLS
        using the stored certificate bundle named *cert*."""
        self.conf({
            "application": application,
            "tls": {
                "certificate": cert
            }
        }, 'listeners/*:' + str(port))

    def remove_tls(self, application='empty', port=7080):
        """Reconfigure listener ``*:<port>`` without the ``tls`` option."""
        self.conf({
            "application": application
        }, 'listeners/*:' + str(port))

    def test_tls_listener_option_add(self):
        self.load('empty')

        self.certificate()

        self.add_tls()

        self.assertEqual(self.get_ssl()['status'], 200, 'add listener option')

    def test_tls_listener_option_remove(self):
        self.load('empty')

        self.certificate()

        self.add_tls()

        self.get_ssl()

        self.remove_tls()

        # After removing TLS the listener must answer plain HTTP again.
        self.assertEqual(self.get()['status'], 200, 'remove listener option')

    def test_tls_certificate_remove(self):
        self.load('empty')

        self.certificate()

        self.assertIn('success', self.conf_delete('/certificates/default'),
                      'remove certificate')

    def test_tls_certificate_remove_used(self):
        self.load('empty')

        self.certificate()

        self.add_tls()

        # A certificate referenced by a listener must not be deletable.
        self.assertIn('error', self.conf_delete('/certificates/default'),
                      'remove certificate')

    def test_tls_certificate_remove_nonexisting(self):
        self.load('empty')

        self.certificate()

        self.add_tls()

        self.assertIn('error', self.conf_delete('/certificates/blah'),
                      'remove nonexisting certificate')

    @unittest.expectedFailure
    def test_tls_certificate_update(self):
        self.load('empty')

        self.certificate()

        self.add_tls()

        cert_old = self.get_server_certificate()

        # Upload a fresh 'default' certificate under the same name.
        self.certificate()

        self.assertNotEqual(cert_old, self.get_server_certificate(),
                            'update certificate')

    @unittest.expectedFailure
    def test_tls_certificate_key_incorrect(self):
        self.load('empty')

        # Generate two unrelated key/cert pairs (the False argument
        # presumably skips uploading them -- see unit.py), then try to upload
        # 'first' together with the key of 'second'.
        self.certificate('first', False)
        self.certificate('second', False)

        self.assertIn('error', self.certificate_load('first', 'second'),
                      'key incorrect')

    def test_tls_certificate_change(self):
        self.load('empty')

        self.certificate()
        self.certificate('new')

        self.add_tls()

        cert_old = self.get_server_certificate()

        # Switch the listener to the other stored bundle.
        self.add_tls(cert='new')

        self.assertNotEqual(cert_old, self.get_server_certificate(),
                            'change certificate')

    def test_tls_certificate_key_rsa(self):
        self.load('empty')

        self.certificate()

        self.assertEqual(self.conf_get('/certificates/default/key'),
                         'RSA (1024 bits)', 'certificate key rsa')

    def test_tls_certificate_key_ec(self):
        self.load('empty')

        # Build a self-signed prime256v1 (EC) certificate with openssl.
        subprocess.call(['openssl', 'ecparam', '-noout', '-genkey',
                         '-out', self.testdir + '/ec.key',
                         '-name', 'prime256v1'])

        subprocess.call(['openssl', 'req', '-x509', '-new',
                         '-config', self.testdir + '/openssl.conf',
                         '-key', self.testdir + '/ec.key', '-subj', '/CN=ec/',
                         '-out', self.testdir + '/ec.crt'])

        self.certificate_load('ec')

        self.assertEqual(self.conf_get('/certificates/ec/key'), 'ECDH',
                         'certificate key ec')

    def test_tls_certificate_chain_options(self):
        self.load('empty')

        self.certificate()

        chain = self.conf_get('/certificates/default/chain')

        self.assertEqual(len(chain), 1, 'certificate chain length')

        cert = chain[0]

        self.assertEqual(cert['subject']['common_name'], 'default',
                         'certificate subject common name')
        self.assertEqual(cert['issuer']['common_name'], 'default',
                         'certificate issuer common name')

        # The certificate was just generated, so "since" must be close to now.
        self.assertLess(abs(self.sec_epoch() -
            self.openssl_date_to_sec_epoch(cert['validity']['since'])), 5,
            'certificate validity since')

        # 2592000 seconds == 30 days of validity.
        self.assertEqual(
            self.openssl_date_to_sec_epoch(cert['validity']['until']) -
            self.openssl_date_to_sec_epoch(cert['validity']['since']), 2592000,
            'certificate validity until')

    def test_tls_certificate_chain(self):
        self.load('empty')

        # Root CA key/certificate: root.key / root.crt.
        self.certificate('root', False)

        # Certificate requests (and keys) for the intermediate and end
        # certificates.
        subprocess.call(['openssl', 'req', '-new', '-config',
                         self.testdir + '/openssl.conf', '-subj', '/CN=int/',
                         '-out', self.testdir + '/int.csr',
                         '-keyout', self.testdir + '/int.key'])

        subprocess.call(['openssl', 'req', '-new', '-config',
                         self.testdir + '/openssl.conf', '-subj', '/CN=end/',
                         '-out', self.testdir + '/end.csr',
                         '-keyout', self.testdir + '/end.key'])

        # Minimal `openssl ca` configuration; CA:TRUE lets the signed
        # certificates act as issuers themselves.
        with open(self.testdir + '/ca.conf', 'w') as f:
            f.write("""[ ca ]
default_ca = myca
[ myca ]
new_certs_dir = %(dir)s
database = %(database)s
default_md = sha1
policy = myca_policy
serial = %(certserial)s
default_days = 1
x509_extensions = myca_extensions
[ myca_policy ]
commonName = supplied
[ myca_extensions ]
basicConstraints = critical,CA:TRUE""" % {
                'dir': self.testdir,
                'database': self.testdir + '/certindex',
                'certserial': self.testdir + '/certserial'
            })

        with open(self.testdir + '/certserial', 'w') as f:
            f.write('1000')

        with open(self.testdir + '/certindex', 'w') as f:
            f.write('')

        # root signs int; int signs end.
        subprocess.call(['openssl', 'ca', '-batch',
                         '-config', self.testdir + '/ca.conf',
                         '-keyfile', self.testdir + '/root.key',
                         '-cert', self.testdir + '/root.crt',
                         '-subj', '/CN=int/',
                         '-in', self.testdir + '/int.csr',
                         '-out', self.testdir + '/int.crt'])

        subprocess.call(['openssl', 'ca', '-batch',
                         '-config', self.testdir + '/ca.conf',
                         '-keyfile', self.testdir + '/int.key',
                         '-cert', self.testdir + '/int.crt',
                         '-subj', '/CN=end/',
                         '-in', self.testdir + '/end.csr',
                         '-out', self.testdir + '/end.crt'])

        # Bundle end + int (leaf first) into end-int.crt.
        with open(self.testdir + '/end-int.crt', 'wb') as crt, \
             open(self.testdir + '/end.crt', 'rb') as end_crt, \
             open(self.testdir + '/int.crt', 'rb') as int_crt:
            crt.write(end_crt.read() + int_crt.read())

        # Client context that trusts only root.crt and requires verification.
        self.context = ssl.create_default_context()
        self.context.check_hostname = False
        self.context.verify_mode = ssl.CERT_REQUIRED
        self.context.load_verify_locations(self.testdir + '/root.crt')

        # incomplete chain: 'end' alone cannot be verified against root.crt
        # because its issuer 'int' is not sent, so the handshake must fail.

        self.assertIn('success', self.certificate_load('end', 'end'),
                      'certificate chain end upload')

        chain = self.conf_get('/certificates/end/chain')
        self.assertEqual(len(chain), 1, 'certificate chain end length')
        self.assertEqual(chain[0]['subject']['common_name'], 'end',
                         'certificate chain end subject common name')
        self.assertEqual(chain[0]['issuer']['common_name'], 'int',
                         'certificate chain end issuer common name')

        self.add_tls(cert='end')

        try:
            resp = self.get_ssl()
        except ssl.SSLError:
            resp = None

        self.assertEqual(resp, None, 'certificate chain incomplete chain')

        # intermediate: 'int' is issued directly by the trusted root.

        self.assertIn('success', self.certificate_load('int', 'int'),
                      'certificate chain int upload')

        chain = self.conf_get('/certificates/int/chain')
        self.assertEqual(len(chain), 1, 'certificate chain int length')
        self.assertEqual(chain[0]['subject']['common_name'], 'int',
                         'certificate chain int subject common name')
        self.assertEqual(chain[0]['issuer']['common_name'], 'root',
                         'certificate chain int issuer common name')

        self.add_tls(cert='int')

        self.assertEqual(self.get_ssl()['status'], 200,
                         'certificate chain intermediate')

        # intermediate server: the server sends end + int, completing the
        # path to the trusted root.

        self.assertIn('success', self.certificate_load('end-int', 'end'),
                      'certificate chain end-int upload')

        chain = self.conf_get('/certificates/end-int/chain')
        self.assertEqual(len(chain), 2, 'certificate chain end-int length')
        self.assertEqual(chain[0]['subject']['common_name'], 'end',
                         'certificate chain end-int end subject common name')
        self.assertEqual(chain[0]['issuer']['common_name'], 'int',
                         'certificate chain end-int end issuer common name')
        self.assertEqual(chain[1]['subject']['common_name'], 'int',
                         'certificate chain end-int int subject common name')
        self.assertEqual(chain[1]['issuer']['common_name'], 'root',
                         'certificate chain end-int int issuer common name')

        self.add_tls(cert='end-int')

        self.assertEqual(self.get_ssl()['status'], 200,
                         'certificate chain intermediate server')

    def test_tls_reconfigure(self):
        self.load('empty')

        self.certificate()

        # Send only the request line over plain HTTP and keep the socket ...
        (_, sock) = self.http(b'GET / HTTP/1.1\n', start=True, raw=True,
                              no_recv=True)

        # ... switch the listener to TLS while that connection is open ...
        self.add_tls()

        # ... then finish the request on the same socket.  The trailing empty
        # line terminates the header section; without it the request is
        # never complete.
        resp = self.http(b'Host: localhost\nConnection: close\n\n',
                         sock=sock, raw=True)

        self.assertEqual(resp['status'], 200, 'update status')
        self.assertEqual(self.get_ssl()['status'], 200, 'update tls status')

    def test_tls_keepalive(self):
        self.load('mirror')

        self.certificate()

        self.add_tls(application='mirror')

        (resp, sock) = self.post_ssl(headers={
            'Connection': 'keep-alive',
            'Content-Type': 'text/html',
            'Host': 'localhost'
        }, start=True, body='0123456789')

        self.assertEqual(resp['body'], '0123456789', 'keepalive 1')

        # Second request reuses the same TLS connection.
        resp = self.post_ssl(headers={
            'Connection': 'close',
            'Content-Type': 'text/html',
            'Host': 'localhost'
        }, sock=sock, body='0123456789')

        self.assertEqual(resp['body'], '0123456789', 'keepalive 2')

    @unittest.expectedFailure
    def test_tls_keepalive_certificate_remove(self):
        self.load('empty')

        self.certificate()

        self.add_tls()

        (resp, sock) = self.get_ssl(headers={
            'Connection': 'keep-alive',
            'Host': 'localhost'
        }, start=True)

        # Drop TLS from the listener and delete the certificate while the
        # keep-alive connection is still open.
        self.remove_tls()
        self.conf_delete('/certificates/default')

        try:
            resp = self.get_ssl(headers={
                'Connection': 'close',
                'Host': 'localhost'
            }, sock=sock)
        except Exception:
            # Any failure on the reused connection counts as "no response".
            resp = None

        self.assertEqual(resp, None, 'keepalive remove certificate')

    @unittest.expectedFailure
    def test_tls_certificates_remove_all(self):
        self.load('empty')

        self.certificate()

        self.assertIn('success', self.conf_delete('/certificates'),
                      'remove all certificates')

    def test_tls_application_respawn(self):
        # The application process is killed on purpose below.
        self.skip_alerts.append(r'process \d+ exited on signal 9')
        self.load('mirror')

        self.certificate()

        self.conf('1', 'applications/mirror/processes')

        self.add_tls(application='mirror')

        (resp, sock) = self.post_ssl(headers={
            'Connection': 'keep-alive',
            'Content-Type': 'text/html',
            'Host': 'localhost'
        }, start=True, body='0123456789')

        app_id = self.findall(r'(\d+)#\d+ "mirror" application started')[0]

        subprocess.call(['kill', '-9', app_id])

        # Wait for a "mirror" start record from a pid other than the killed
        # one (negative lookahead on "<app_id>#").
        self.wait_for_record(re.compile(
            r' (?!' + re.escape(app_id) + r'#)(\d+)#\d+'
            r' "mirror" application started'))

        # The old TLS connection must still be served by the new process.
        resp = self.post_ssl(headers={
            'Connection': 'close',
            'Content-Type': 'text/html',
            'Host': 'localhost'
        }, sock=sock, body='0123456789')

        self.assertEqual(resp['status'], 200, 'application respawn status')
        self.assertEqual(resp['body'], '0123456789', 'application respawn body')
if __name__ == '__main__':
    # Entry point when the file is executed directly as a script:
    # hand control to the standard unittest runner.
    unittest.main()