|
1 | 1 | import gzip |
2 | 2 | from http.server import BaseHTTPRequestHandler, HTTPServer |
3 | 3 | import os |
| 4 | +import ssl |
4 | 5 | import threading |
5 | 6 | import time |
6 | 7 | import unittest |
| 8 | +import urllib |
7 | 9 |
|
8 | 10 | import pytest |
9 | 11 |
|
|
16 | 18 | from prometheus_client.core import GaugeHistogramMetricFamily, Timestamp |
17 | 19 | from prometheus_client.exposition import ( |
18 | 20 | basic_auth_handler, choose_encoder, default_handler, MetricsHandler, |
19 | | - passthrough_redirect_handler, tls_auth_handler, |
| 21 | + passthrough_redirect_handler, start_wsgi_server, tls_auth_handler, |
20 | 22 | ) |
21 | 23 | import prometheus_client.openmetrics.exposition as openmetrics |
22 | 24 |
|
@@ -633,6 +635,148 @@ def test_prom_no_version(self): |
633 | 635 | self.assert_is_prom(exp) |
634 | 636 |
|
635 | 637 |
|
| 638 | +class TestWsgiTLS(unittest.TestCase): |
| 639 | + def setUp(self): |
| 640 | + self.certs_dir = os.path.join( |
| 641 | + os.path.dirname(os.path.realpath(__file__)), 'certs' |
| 642 | + ) |
| 643 | + self.httpd = None |
| 644 | + self.t = None |
| 645 | + |
| 646 | + def tearDown(self): |
| 647 | + if self.httpd: |
| 648 | + self.httpd.shutdown() |
| 649 | + self.httpd.server_close() |
| 650 | + self.t.join() |
| 651 | + |
| 652 | + def _assert_tls_connection( |
| 653 | + self, |
| 654 | + server_kwargs, |
| 655 | + use_server_tls=True, |
| 656 | + client_tls_kwargs=None, |
| 657 | + request_tls_version=ssl.TLSVersion.TLSv1_3, |
| 658 | + expect_exception=None |
| 659 | + ): |
| 660 | + self.httpd, self.t = start_wsgi_server(port=0, **server_kwargs) |
| 661 | + port = self.httpd.server_address[1] |
| 662 | + |
| 663 | + if use_server_tls: |
| 664 | + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) |
| 665 | + ctx.minimum_version = request_tls_version |
| 666 | + ctx.maximum_version = request_tls_version |
| 667 | + ctx.load_verify_locations( |
| 668 | + os.path.join(self.certs_dir, "server-ca.pem") |
| 669 | + ) |
| 670 | + |
| 671 | + if client_tls_kwargs is not None: |
| 672 | + ctx.load_cert_chain(**client_tls_kwargs) |
| 673 | + |
| 674 | + url = f"https://localhost:{port}/metrics" |
| 675 | + else: |
| 676 | + ctx = None |
| 677 | + url = f"http://localhost:{port}/metrics" |
| 678 | + |
| 679 | + if expect_exception is not None: |
| 680 | + self.assertRaises( |
| 681 | + expect_exception, |
| 682 | + urllib.request.urlopen, |
| 683 | + url, |
| 684 | + context=ctx |
| 685 | + ) |
| 686 | + else: |
| 687 | + response = urllib.request.urlopen(url, context=ctx) |
| 688 | + self.assertEqual(response.status, 200) |
| 689 | + |
| 690 | + def test_tls_disabled(self): |
| 691 | + self._assert_tls_connection(server_kwargs={}, use_server_tls=False) |
| 692 | + |
| 693 | + def test_tls_enabled(self): |
| 694 | + server_kwargs = { |
| 695 | + "certfile": os.path.join(self.certs_dir, "server-cert.pem"), |
| 696 | + "keyfile": os.path.join(self.certs_dir, "server-key.pem"), |
| 697 | + } |
| 698 | + self._assert_tls_connection(server_kwargs) |
| 699 | + |
| 700 | + def test_tls_untrusted_server_cert_raises(self): |
| 701 | + server_kwargs = { |
| 702 | + "certfile": os.path.join(self.certs_dir, "cert.pem"), |
| 703 | + "keyfile": os.path.join(self.certs_dir, "key.pem"), |
| 704 | + } |
| 705 | + self._assert_tls_connection( |
| 706 | + server_kwargs, |
| 707 | + expect_exception=urllib.error.URLError |
| 708 | + ) |
| 709 | + |
| 710 | + def test_tls_versions_configured_correctly(self): |
| 711 | + server_kwargs = { |
| 712 | + "certfile": os.path.join(self.certs_dir, "server-cert.pem"), |
| 713 | + "keyfile": os.path.join(self.certs_dir, "server-key.pem"), |
| 714 | + "tls_min_version": ssl.TLSVersion.TLSv1_2, |
| 715 | + "tls_max_version": ssl.TLSVersion.TLSv1_3, |
| 716 | + } |
| 717 | + self._assert_tls_connection( |
| 718 | + server_kwargs, |
| 719 | + request_tls_version=ssl.TLSVersion.TLSv1_2 |
| 720 | + ) |
| 721 | + |
| 722 | + def test_tls_using_lower_version_than_min_raises(self): |
| 723 | + server_kwargs = { |
| 724 | + "certfile": os.path.join(self.certs_dir, "server-cert.pem"), |
| 725 | + "keyfile": os.path.join(self.certs_dir, "server-key.pem"), |
| 726 | + "tls_min_version": ssl.TLSVersion.TLSv1_3, |
| 727 | + } |
| 728 | + self._assert_tls_connection( |
| 729 | + server_kwargs, |
| 730 | + request_tls_version=ssl.TLSVersion.TLSv1_2, |
| 731 | + expect_exception=urllib.error.URLError |
| 732 | + ) |
| 733 | + |
| 734 | + def test_tls_using_higher_version_than_max_raises(self): |
| 735 | + server_kwargs = { |
| 736 | + "certfile": os.path.join(self.certs_dir, "server-cert.pem"), |
| 737 | + "keyfile": os.path.join(self.certs_dir, "server-key.pem"), |
| 738 | + "tls_max_version": ssl.TLSVersion.TLSv1_2, |
| 739 | + } |
| 740 | + self._assert_tls_connection( |
| 741 | + server_kwargs, |
| 742 | + request_tls_version=ssl.TLSVersion.TLSv1_3, |
| 743 | + expect_exception=urllib.error.URLError |
| 744 | + ) |
| 745 | + |
| 746 | + def test_mtls_enabled(self): |
| 747 | + server_kwargs = { |
| 748 | + "certfile": os.path.join(self.certs_dir, "server-cert.pem"), |
| 749 | + "keyfile": os.path.join(self.certs_dir, "server-key.pem"), |
| 750 | + "client_auth_required": True, |
| 751 | + "client_cafile": os.path.join(self.certs_dir, "server-ca.pem"), |
| 752 | + } |
| 753 | + client_tls_kwargs = { |
| 754 | + "certfile": os.path.join(self.certs_dir, "client-cert.pem"), |
| 755 | + "keyfile": os.path.join(self.certs_dir, "client-key.pem") |
| 756 | + } |
| 757 | + self._assert_tls_connection( |
| 758 | + server_kwargs, |
| 759 | + client_tls_kwargs=client_tls_kwargs |
| 760 | + ) |
| 761 | + |
| 762 | + def test_mtls_untrusted_client_cert_raises(self): |
| 763 | + server_kwargs = { |
| 764 | + "certfile": os.path.join(self.certs_dir, "server-cert.pem"), |
| 765 | + "keyfile": os.path.join(self.certs_dir, "server-key.pem"), |
| 766 | + "client_auth_required": True, |
| 767 | + "client_cafile": os.path.join(self.certs_dir, "server-cert.pem"), |
| 768 | + } |
| 769 | + client_tls_kwargs = { |
| 770 | + "certfile": os.path.join(self.certs_dir, "cert.pem"), |
| 771 | + "keyfile": os.path.join(self.certs_dir, "key.pem") |
| 772 | + } |
| 773 | + self._assert_tls_connection( |
| 774 | + server_kwargs, |
| 775 | + client_tls_kwargs=client_tls_kwargs, |
| 776 | + expect_exception=ssl.SSLError |
| 777 | + ) |
| 778 | + |
| 779 | + |
636 | 780 | @pytest.mark.parametrize("scenario", [ |
637 | 781 | { |
638 | 782 | "name": "empty string", |
|
0 commit comments