diff --git a/lib/charms/tls_certificates_interface/v2/tls_certificates.py b/lib/charms/tls_certificates_interface/v3/tls_certificates.py similarity index 79% rename from lib/charms/tls_certificates_interface/v2/tls_certificates.py rename to lib/charms/tls_certificates_interface/v3/tls_certificates.py index ff234ff..e63a77c 100644 --- a/lib/charms/tls_certificates_interface/v2/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v3/tls_certificates.py @@ -1,4 +1,4 @@ -# Copyright 2021 Canonical Ltd. +# Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. @@ -11,7 +11,7 @@ From a charm directory, fetch the library using `charmcraft`: ```shell -charmcraft fetch-lib charms.tls_certificates_interface.v2.tls_certificates +charmcraft fetch-lib charms.tls_certificates_interface.v3.tls_certificates ``` Add the following libraries to the charm's `requirements.txt` file: @@ -36,10 +36,10 @@ Example: ```python -from charms.tls_certificates_interface.v2.tls_certificates import ( +from charms.tls_certificates_interface.v3.tls_certificates import ( CertificateCreationRequestEvent, CertificateRevocationRequestEvent, - TLSCertificatesProvidesV2, + TLSCertificatesProvidesV3, generate_private_key, ) from ops.charm import CharmBase, InstallEvent @@ -59,7 +59,7 @@ class ExampleProviderCharm(CharmBase): def __init__(self, *args): super().__init__(*args) - self.certificates = TLSCertificatesProvidesV2(self, "certificates") + self.certificates = TLSCertificatesProvidesV3(self, "certificates") self.framework.observe( self.certificates.on.certificate_request, self._on_certificate_request @@ -126,11 +126,11 @@ def _on_certificate_revocation_request(self, event: CertificateRevocationRequest Example: ```python -from charms.tls_certificates_interface.v2.tls_certificates import ( +from charms.tls_certificates_interface.v3.tls_certificates import ( CertificateAvailableEvent, CertificateExpiringEvent, CertificateRevokedEvent, - TLSCertificatesRequiresV2, + TLSCertificatesRequiresV3, generate_csr, generate_private_key, ) @@ -145,7 +145,7 @@ class ExampleRequirerCharm(CharmBase): def __init__(self, *args): super().__init__(*args) self.cert_subject = "whatever" - self.certificates = TLSCertificatesRequiresV2(self, "certificates") + self.certificates = TLSCertificatesRequiresV3(self, "certificates") self.framework.observe(self.on.install, self._on_install) self.framework.observe( self.on.certificates_relation_joined, self._on_certificates_relation_joined @@ -277,15 +277,15 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven import logging import uuid from contextlib import suppress +from dataclasses import dataclass from datetime import datetime, timedelta from ipaddress import IPv4Address -from typing import Any, Dict, List, Literal, Optional, Union +from typing import List, Literal, Optional, Union from cryptography import x509 from cryptography.hazmat._oid import ExtensionOID from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.serialization import pkcs12 from jsonschema import exceptions, validate # type: ignore[import-untyped] from ops.charm import ( CharmBase, @@ -293,21 +293,26 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven RelationBrokenEvent, RelationChangedEvent, SecretExpiredEvent, - UpdateStatusEvent, ) from ops.framework import EventBase, EventSource, Handle, Object -from ops.jujuversion import JujuVersion -from ops.model import ModelError, Relation, RelationDataContent, SecretNotFoundError +from ops.model import ( + Application, + ModelError, + Relation, + RelationDataContent, + SecretNotFoundError, + Unit, +) # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" # Increment this major API version when introducing breaking changes -LIBAPI = 2 +LIBAPI = 3 # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 23 +LIBPATCH = 0 PYDEPS = ["cryptography", "jsonschema"] @@ -422,6 +427,30 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven logger = logging.getLogger(__name__) +@dataclass +class RequirerCSR: + """This class represents a certificate signing request from an interface Requirer.""" + + relation_id: int + application_name: str + unit_name: str + csr: str + is_ca: bool + + +@dataclass +class ProviderCertificate: + """This class represents a certificate from an interface Provider.""" + + relation_id: int + application_name: str + csr: str + certificate: str + ca: str + chain: List[str] + revoked: bool + + class CertificateAvailableEvent(EventBase): """Charm Event triggered when a TLS certificate is available.""" @@ -884,38 +913,6 @@ def generate_certificate( return cert.public_bytes(serialization.Encoding.PEM) -def generate_pfx_package( - certificate: bytes, - private_key: bytes, - package_password: str, - private_key_password: Optional[bytes] = None, -) -> bytes: - """Generates a PFX package to contain the TLS certificate and private key. - - Args: - certificate (bytes): TLS certificate - private_key (bytes): Private key - package_password (str): Password to open the PFX package - private_key_password (bytes): Private key password - - Returns: - bytes: - """ - private_key_object = serialization.load_pem_private_key( - private_key, password=private_key_password - ) - certificate_object = x509.load_pem_x509_certificate(certificate) - name = certificate_object.subject.rfc4514_string() - pfx_bytes = pkcs12.serialize_key_and_certificates( - name=name.encode(), - cert=certificate_object, - key=private_key_object, # type: ignore[arg-type] - cas=None, - encryption_algorithm=serialization.BestAvailableEncryption(package_password.encode()), - ) - return pfx_bytes - - def generate_private_key( password: Optional[bytes] = None, key_size: int = 2048, @@ -938,9 +935,11 @@ def generate_private_key( key_bytes = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.BestAvailableEncryption(password) - if password - else serialization.NoEncryption(), + encryption_algorithm=( + serialization.BestAvailableEncryption(password) + if password + else serialization.NoEncryption() + ), ) return key_bytes @@ -1049,6 +1048,26 @@ def csr_matches_certificate(csr: str, cert: str) -> bool: return True +def _relation_data_is_valid( + relation: Relation, app_or_unit: Union[Application, Unit], json_schema: dict +) -> bool: + """Checks whether relation data is valid based on json schema. + + Args: + relation (Relation): Relation object + app_or_unit (Union[Application, Unit]): Application or unit object + + Returns: + bool: Whether relation data is valid. + """ + relation_data = _load_relation_data(relation.data[app_or_unit]) + try: + validate(instance=relation_data, schema=json_schema) + return True + except exceptions.ValidationError: + return False + + class CertificatesProviderCharmEvents(CharmEvents): """List of events that the TLS Certificates provider charm can leverage.""" @@ -1065,7 +1084,7 @@ class CertificatesRequirerCharmEvents(CharmEvents): all_certificates_invalidated = EventSource(AllCertificatesInvalidatedEvent) -class TLSCertificatesProvidesV2(Object): +class TLSCertificatesProvidesV3(Object): """TLS certificates provider class to be instantiated by TLS certificates providers.""" on = CertificatesProviderCharmEvents() @@ -1174,22 +1193,6 @@ def _remove_certificate( certificates.remove(certificate_dict) relation.data[self.model.app]["certificates"] = json.dumps(certificates) - @staticmethod - def _relation_data_is_valid(certificates_data: dict) -> bool: - """Uses JSON schema validator to validate relation data content. - - Args: - certificates_data (dict): Certificate data dictionary as retrieved from relation data. - - Returns: - bool: True/False depending on whether the relation data follows the json schema. - """ - try: - validate(instance=certificates_data, schema=REQUIRER_JSON_SCHEMA) - return True - except exceptions.ValidationError: - return False - def revoke_all_certificates(self) -> None: """Revokes all certificates of this provider. @@ -1258,16 +1261,24 @@ def remove_certificate(self, certificate: str) -> None: def get_issued_certificates( self, relation_id: Optional[int] = None - ) -> Dict[str, List[Dict[str, str]]]: - """Returns a dictionary of issued certificates. + ) -> List[ProviderCertificate]: + """Returns a List of issued (non revoked) certificates. + + Returns: + List: List of ProviderCertificate objects + """ + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + return [certificate for certificate in provider_certificates if not certificate.revoked] - It returns certificates from all relations if relation_id is not specified. - Certificates are returned per application name and CSR. + def get_provider_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Returns a List of issued certificates. Returns: - dict: Certificates per application name. + List: List of ProviderCertificate objects """ - certificates: Dict[str, List[Dict[str, str]]] = {} + certificates: List[ProviderCertificate] = [] relations = ( [ relation @@ -1278,19 +1289,22 @@ def get_issued_certificates( else self.model.relations.get(self.relationship_name, []) ) for relation in relations: + if not relation.app: + logger.warning("Relation %s does not have an application", relation.id) + continue provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) - - certificates[relation.app.name] = [] # type: ignore[union-attr] for certificate in provider_certificates: - if not certificate.get("revoked", False): - certificates[relation.app.name].append( # type: ignore[union-attr] - { - "csr": certificate["certificate_signing_request"], - "certificate": certificate["certificate"], - } - ) - + provider_certificate = ProviderCertificate( + relation_id=relation.id, + application_name=relation.app.name, + csr=certificate["certificate_signing_request"], + certificate=certificate["certificate"], + ca=certificate["ca"], + chain=certificate["chain"], + revoked=certificate.get("revoked", False), + ) + certificates.append(provider_certificate) return certificates def _on_relation_changed(self, event: RelationChangedEvent) -> None: @@ -1313,124 +1327,77 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: return if not self.model.unit.is_leader(): return - requirer_relation_data = _load_relation_data(event.relation.data[event.unit]) - provider_relation_data = self._load_app_relation_data(event.relation) - if not self._relation_data_is_valid(requirer_relation_data): + if not _relation_data_is_valid(event.relation, event.unit, REQUIRER_JSON_SCHEMA): logger.debug("Relation data did not pass JSON Schema validation") return - provider_certificates = provider_relation_data.get("certificates", []) - requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) + provider_certificates = self.get_provider_certificates(relation_id=event.relation.id) + requirer_csrs = self.get_requirer_csrs(relation_id=event.relation.id) provider_csrs = [ - certificate_creation_request["certificate_signing_request"] + certificate_creation_request.csr for certificate_creation_request in provider_certificates ] - requirer_unit_certificate_requests = [ - { - "csr": certificate_creation_request["certificate_signing_request"], - "is_ca": certificate_creation_request.get("ca", False), - } - for certificate_creation_request in requirer_csrs - ] - for certificate_request in requirer_unit_certificate_requests: - if certificate_request["csr"] not in provider_csrs: + for certificate_request in requirer_csrs: + if certificate_request.csr not in provider_csrs: self.on.certificate_creation_request.emit( - certificate_signing_request=certificate_request["csr"], - relation_id=event.relation.id, - is_ca=certificate_request["is_ca"], + certificate_signing_request=certificate_request.csr, + relation_id=certificate_request.relation_id, + is_ca=certificate_request.is_ca, ) self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None: """Revokes certificates for which no unit has a CSR. - Goes through all generated certificates and compare against the list of CSRs for all units - of a given relationship. - - Args: - relation_id (int): Relation id + Goes through all generated certificates and compare against the list of CSRs for all units. Returns: None """ - certificates_relation = self.model.get_relation( - relation_name=self.relationship_name, relation_id=relation_id - ) - if not certificates_relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") - provider_relation_data = self._load_app_relation_data(certificates_relation) - list_of_csrs: List[str] = [] - for unit in certificates_relation.units: - requirer_relation_data = _load_relation_data(certificates_relation.data[unit]) - requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) - list_of_csrs.extend(csr["certificate_signing_request"] for csr in requirer_csrs) - provider_certificates = provider_relation_data.get("certificates", []) + provider_certificates = self.get_provider_certificates(relation_id) + requirer_csrs = self.get_requirer_csrs(relation_id) + list_of_csrs = [csr.csr for csr in requirer_csrs] for certificate in provider_certificates: - if certificate["certificate_signing_request"] not in list_of_csrs: + if certificate.csr not in list_of_csrs: self.on.certificate_revocation_request.emit( - certificate=certificate["certificate"], - certificate_signing_request=certificate["certificate_signing_request"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate=certificate.certificate, + certificate_signing_request=certificate.csr, + ca=certificate.ca, + chain=certificate.chain, ) - self.remove_certificate(certificate=certificate["certificate"]) + self.remove_certificate(certificate=certificate.certificate) def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None - ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: + ) -> List[RequirerCSR]: """Returns CSR's for which no certificate has been issued. - Example return: [ - { - "relation_id": 0, - "application_name": "tls-certificates-requirer", - "unit_name": "tls-certificates-requirer/0", - "unit_csrs": [ - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - "is_ca": false - } - ] - } - ] - Args: relation_id (int): Relation id Returns: - list: List of dictionaries that contain the unit's csrs - that don't have a certificate issued. + list: List of RequirerCSR objects. """ - all_unit_csr_mappings = copy.deepcopy(self.get_requirer_csrs(relation_id=relation_id)) - filtered_all_unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] - for unit_csr_mapping in all_unit_csr_mappings: - csrs_without_certs = [] - for csr in unit_csr_mapping["unit_csrs"]: # type: ignore[union-attr] - if not self.certificate_issued_for_csr( - app_name=unit_csr_mapping["application_name"], # type: ignore[arg-type] - csr=csr["certificate_signing_request"], # type: ignore[index] - relation_id=relation_id, - ): - csrs_without_certs.append(csr) - if csrs_without_certs: - unit_csr_mapping["unit_csrs"] = csrs_without_certs # type: ignore[assignment] - filtered_all_unit_csr_mappings.append(unit_csr_mapping) - return filtered_all_unit_csr_mappings - - def get_requirer_csrs( - self, relation_id: Optional[int] = None - ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Returns a list of requirers' CSRs grouped by unit. + requirer_csrs = self.get_requirer_csrs(relation_id=relation_id) + outstanding_csrs: List[RequirerCSR] = [] + for relation_csr in requirer_csrs: + if not self.certificate_issued_for_csr( + app_name=relation_csr.application_name, + csr=relation_csr.csr, + relation_id=relation_id, + ): + outstanding_csrs.append(relation_csr) + return outstanding_csrs + + def get_requirer_csrs(self, relation_id: Optional[int] = None) -> List[RequirerCSR]: + """Returns a list of requirers' CSRs. It returns CSRs from all relations if relation_id is not specified. CSRs are returned per relation id, application name and unit name. Returns: - list: List of dictionaries that contain the unit's csrs - with the following information - relation_id, application_name and unit_name. + list: List[RequirerCSR] """ - unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] - + relation_csrs: List[RequirerCSR] = [] relations = ( [ relation @@ -1445,15 +1412,24 @@ def get_requirer_csrs( for unit in relation.units: requirer_relation_data = _load_relation_data(relation.data[unit]) unit_csrs_list = requirer_relation_data.get("certificate_signing_requests", []) - unit_csr_mappings.append( - { - "relation_id": relation.id, - "application_name": relation.app.name, # type: ignore[union-attr] - "unit_name": unit.name, - "unit_csrs": unit_csrs_list, - } - ) - return unit_csr_mappings + for unit_csr in unit_csrs_list: + csr = unit_csr.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + ca = unit_csr.get("ca", False) + if not relation.app: + logger.warning("No remote app in relation - Skipping") + continue + relation_csr = RequirerCSR( + relation_id=relation.id, + application_name=relation.app.name, + unit_name=unit.name, + csr=csr, + is_ca=ca, + ) + relation_csrs.append(relation_csr) + return relation_csrs def certificate_issued_for_csr( self, app_name: str, csr: str, relation_id: Optional[int] @@ -1464,19 +1440,18 @@ def certificate_issued_for_csr( app_name (str): Application name that the CSR belongs to. csr (str): Certificate Signing Request. relation_id (Optional[int]): Relation ID + Returns: bool: True/False depending on whether a certificate has been issued for the given CSR. """ - issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)[ - app_name - ] - for issued_pair in issued_certificates_per_csr: - if "csr" in issued_pair and issued_pair["csr"] == csr: - return csr_matches_certificate(csr, issued_pair["certificate"]) + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) + for issued_certificate in issued_certificates_per_csr: + if issued_certificate.csr == csr and issued_certificate.application_name == app_name: + return csr_matches_certificate(csr, issued_certificate.certificate) return False -class TLSCertificatesRequiresV2(Object): +class TLSCertificatesRequiresV3(Object): """TLS certificates requirer class to be instantiated by TLS certificates requirers.""" on = CertificatesRequirerCharmEvents() @@ -1505,32 +1480,39 @@ def __init__( self.framework.observe( charm.on[relationship_name].relation_broken, self._on_relation_broken ) - if JujuVersion.from_environ().has_secrets: - self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - else: - self.framework.observe(charm.on.update_status, self._on_update_status) + self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - @property - def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]: + def get_requirer_csrs(self) -> List[RequirerCSR]: """Returns list of requirer's CSRs from relation unit data. - Example: - [ - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - "ca": false - } - ] + Returns: + list: List of RequirerCSR objects. """ + requirer_csrs = [] relation = self.model.get_relation(self.relationship_name) if not relation: raise RuntimeError(f"Relation {self.relationship_name} does not exist") requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - return requirer_relation_data.get("certificate_signing_requests", []) + requirer_csrs_dict = requirer_relation_data.get("certificate_signing_requests", []) + for requirer_csr_dict in requirer_csrs_dict: + csr = requirer_csr_dict.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + ca = requirer_csr_dict.get("ca", False) + relation_csr = RequirerCSR( + relation_id=relation.id, + application_name=self.model.app.name, + unit_name=self.model.unit.name, + csr=csr, + is_ca=ca, + ) + requirer_csrs.append(relation_csr) + return requirer_csrs - @property - def _provider_certificates(self) -> List[Dict[str, str]]: + def get_provider_certificates(self) -> List[ProviderCertificate]: """Returns list of certificates from the provider's relation data.""" + provider_certificates: List[ProviderCertificate] = [] relation = self.model.get_relation(self.relationship_name) if not relation: logger.debug("No relation: %s", self.relationship_name) @@ -1539,12 +1521,32 @@ def _provider_certificates(self) -> List[Dict[str, str]]: logger.debug("No remote app in relation: %s", self.relationship_name) return [] provider_relation_data = _load_relation_data(relation.data[relation.app]) - if not self._relation_data_is_valid(provider_relation_data): - logger.warning("Provider relation data did not pass JSON Schema validation") - return [] - return provider_relation_data.get("certificates", []) + provider_certificate_dicts = provider_relation_data.get("certificates", []) + for provider_certificate_dict in provider_certificate_dicts: + certificate = provider_certificate_dict.get("certificate") + if not certificate: + logger.warning("No certificate found in relation data - Skipping") + continue + ca = provider_certificate_dict.get("ca") + chain = provider_certificate_dict.get("chain", []) + csr = provider_certificate_dict.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + revoked = provider_certificate_dict.get("revoked", False) + provider_certificate = ProviderCertificate( + relation_id=relation.id, + application_name=relation.app.name, + csr=csr, + certificate=certificate, + ca=ca, + chain=chain, + revoked=revoked, + ) + provider_certificates.append(provider_certificate) + return provider_certificates - def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: + def _add_requirer_csr_to_relation_data(self, csr: str, is_ca: bool) -> None: """Adds CSR to relation data. Args: @@ -1560,18 +1562,23 @@ def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - new_csr_dict: Dict[str, Union[bool, str]] = { + for requirer_csr in self.get_requirer_csrs(): + if requirer_csr.csr == csr and requirer_csr.is_ca == is_ca: + logger.info("CSR already in relation data - Doing nothing") + return + new_csr_dict = { "certificate_signing_request": csr, "ca": is_ca, } - if new_csr_dict in self._requirer_csrs: - logger.info("CSR already in relation data - Doing nothing") - return - requirer_csrs = copy.deepcopy(self._requirer_csrs) - requirer_csrs.append(new_csr_dict) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) + requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) + existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) + new_relation_data = copy.deepcopy(existing_relation_data) + new_relation_data.append(new_csr_dict) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( + new_relation_data + ) - def _remove_requirer_csr(self, csr: str) -> None: + def _remove_requirer_csr_from_relation_data(self, csr: str) -> None: """Removes CSR from relation data. Args: @@ -1586,14 +1593,18 @@ def _remove_requirer_csr(self, csr: str) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - requirer_csrs = copy.deepcopy(self._requirer_csrs) - if not requirer_csrs: + if not self.get_requirer_csrs(): logger.info("No CSRs in relation data - Doing nothing") return - for requirer_csr in requirer_csrs: + requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) + existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) + new_relation_data = copy.deepcopy(existing_relation_data) + for requirer_csr in new_relation_data: if requirer_csr["certificate_signing_request"] == csr: - requirer_csrs.remove(requirer_csr) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) + new_relation_data.remove(requirer_csr) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( + new_relation_data + ) def request_certificate_creation( self, certificate_signing_request: bytes, is_ca: bool = False @@ -1613,7 +1624,9 @@ def request_certificate_creation( f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - self._add_requirer_csr(certificate_signing_request.decode().strip(), is_ca=is_ca) + self._add_requirer_csr_to_relation_data( + certificate_signing_request.decode().strip(), is_ca=is_ca + ) logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: @@ -1629,7 +1642,7 @@ def request_certificate_revocation(self, certificate_signing_request: bytes) -> Returns: None """ - self._remove_requirer_csr(certificate_signing_request.decode().strip()) + self._remove_requirer_csr_from_relation_data(certificate_signing_request.decode().strip()) logger.info("Certificate revocation sent to provider") def request_certificate_renewal( @@ -1657,107 +1670,62 @@ def request_certificate_renewal( ) logger.info("Certificate renewal request completed.") - def get_assigned_certificates(self) -> List[Dict[str, str]]: + def get_assigned_certificates(self) -> List[ProviderCertificate]: """Get a list of certificates that were assigned to this unit. Returns: - List of certificates. For example: - [ - { - "ca": "-----BEGIN CERTIFICATE-----...", - "chain": [ - "-----BEGIN CERTIFICATE-----..." - ], - "certificate": "-----BEGIN CERTIFICATE-----...", - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - } - ] + List: List[ProviderCertificate] """ - final_list = [] - for csr in self.get_certificate_signing_requests(fulfilled_only=True): - assert isinstance(csr["certificate_signing_request"], str) - if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): - final_list.append(cert) - return final_list - - def get_expiring_certificates(self) -> List[Dict[str, str]]: + assigned_certificates = [] + for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): + if cert := self._find_certificate_in_relation_data(requirer_csr.csr): + assigned_certificates.append(cert) + return assigned_certificates + + def get_expiring_certificates(self) -> List[ProviderCertificate]: """Get a list of certificates that were assigned to this unit that are expiring or expired. Returns: - List of certificates. For example: - [ - { - "ca": "-----BEGIN CERTIFICATE-----...", - "chain": [ - "-----BEGIN CERTIFICATE-----..." - ], - "certificate": "-----BEGIN CERTIFICATE-----...", - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - } - ] + List: List[ProviderCertificate] """ - final_list = [] - for csr in self.get_certificate_signing_requests(fulfilled_only=True): - assert isinstance(csr["certificate_signing_request"], str) - if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): - expiry_time = _get_certificate_expiry_time(cert["certificate"]) + expiring_certificates: List[ProviderCertificate] = [] + for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): + if cert := self._find_certificate_in_relation_data(requirer_csr.csr): + expiry_time = _get_certificate_expiry_time(cert.certificate) if not expiry_time: continue expiry_notification_time = expiry_time - timedelta( hours=self.expiry_notification_time ) if datetime.utcnow() > expiry_notification_time: - final_list.append(cert) - return final_list + expiring_certificates.append(cert) + return expiring_certificates def get_certificate_signing_requests( self, fulfilled_only: bool = False, unfulfilled_only: bool = False, - ) -> List[Dict[str, Union[bool, str]]]: + ) -> List[RequirerCSR]: """Gets the list of CSR's that were sent to the provider. You can choose to get only the CSR's that have a certificate assigned or only the CSR's - that don't. + that don't. Args: fulfilled_only (bool): This option will discard CSRs that don't have certificates yet. unfulfilled_only (bool): This option will discard CSRs that have certificates signed. Returns: - List of CSR dictionaries. For example: - [ - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - "ca": false - } - ] + List of RequirerCSR objects. """ - final_list = [] - for csr in self._requirer_csrs: - assert isinstance(csr["certificate_signing_request"], str) - cert = self._find_certificate_in_relation_data(csr["certificate_signing_request"]) + csrs = [] + for requirer_csr in self.get_requirer_csrs(): + cert = self._find_certificate_in_relation_data(requirer_csr.csr) if (unfulfilled_only and cert) or (fulfilled_only and not cert): continue - final_list.append(csr) - - return final_list + csrs.append(requirer_csr) - @staticmethod - def _relation_data_is_valid(certificates_data: dict) -> bool: - """Checks whether relation data is valid based on json schema. - - Args: - certificates_data: Certificate data in dict format. - - Returns: - bool: Whether relation data is valid. - """ - try: - validate(instance=certificates_data, schema=PROVIDER_JSON_SCHEMA) - return True - except exceptions.ValidationError: - return False + return csrs def _on_relation_changed(self, event: RelationChangedEvent) -> None: """Handler triggered on relation changed events. @@ -1777,51 +1745,48 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: Returns: None """ + if not event.app: + logger.warning("No remote app in relation - Skipping") + return + if not _relation_data_is_valid(event.relation, event.app, PROVIDER_JSON_SCHEMA): + logger.debug("Relation data did not pass JSON Schema validation") + return + provider_certificates = self.get_provider_certificates() requirer_csrs = [ - certificate_creation_request["certificate_signing_request"] - for certificate_creation_request in self._requirer_csrs + certificate_creation_request.csr + for certificate_creation_request in self.get_requirer_csrs() ] - for certificate in self._provider_certificates: - if certificate["certificate_signing_request"] in requirer_csrs: - if certificate.get("revoked", False): - if JujuVersion.from_environ().has_secrets: - with suppress(SecretNotFoundError): - secret = self.model.get_secret( - label=f"{LIBID}-{certificate['certificate_signing_request']}" - ) - secret.remove_all_revisions() + for certificate in provider_certificates: + if certificate.csr in requirer_csrs: + if certificate.revoked: + with suppress(SecretNotFoundError): + secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") + secret.remove_all_revisions() self.on.certificate_invalidated.emit( reason="revoked", - certificate=certificate["certificate"], - certificate_signing_request=certificate["certificate_signing_request"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate=certificate.certificate, + certificate_signing_request=certificate.csr, + ca=certificate.ca, + chain=certificate.chain, ) else: - if JujuVersion.from_environ().has_secrets: - try: - secret = self.model.get_secret( - label=f"{LIBID}-{certificate['certificate_signing_request']}" - ) - secret.set_content({"certificate": certificate["certificate"]}) - secret.set_info( - expire=self._get_next_secret_expiry_time( - certificate["certificate"] - ), - ) - except SecretNotFoundError: - secret = self.charm.unit.add_secret( - {"certificate": certificate["certificate"]}, - label=f"{LIBID}-{certificate['certificate_signing_request']}", - expire=self._get_next_secret_expiry_time( - certificate["certificate"] - ), - ) + try: + secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") + secret.set_content({"certificate": certificate.certificate}) + secret.set_info( + expire=self._get_next_secret_expiry_time(certificate.certificate), + ) + except SecretNotFoundError: + secret = self.charm.unit.add_secret( + {"certificate": certificate.certificate}, + label=f"{LIBID}-{certificate.csr}", + expire=self._get_next_secret_expiry_time(certificate.certificate), + ) self.on.certificate_available.emit( - certificate_signing_request=certificate["certificate_signing_request"], - certificate=certificate["certificate"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate_signing_request=certificate.csr, + certificate=certificate.certificate, + ca=certificate.ca, + chain=certificate.chain, ) def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: @@ -1877,13 +1842,13 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): return csr = event.secret.label[len(f"{LIBID}-") :] - certificate_dict = self._find_certificate_in_relation_data(csr) - if not certificate_dict: + provider_certificate = self._find_certificate_in_relation_data(csr) + if not provider_certificate: # A secret expired but we did not find matching certificate. Cleaning up event.secret.remove_all_revisions() return - expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) + expiry_time = _get_certificate_expiry_time(provider_certificate.certificate) if not expiry_time: # A secret expired but matching certificate is invalid. Cleaning up event.secret.remove_all_revisions() @@ -1892,64 +1857,28 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: if datetime.utcnow() < expiry_time: logger.warning("Certificate almost expired") self.on.certificate_expiring.emit( - certificate=certificate_dict["certificate"], + certificate=provider_certificate.certificate, expiry=expiry_time.isoformat(), ) event.secret.set_info( - expire=_get_certificate_expiry_time(certificate_dict["certificate"]), + expire=_get_certificate_expiry_time(provider_certificate.certificate), ) else: logger.warning("Certificate is expired") self.on.certificate_invalidated.emit( reason="expired", - certificate=certificate_dict["certificate"], - certificate_signing_request=certificate_dict["certificate_signing_request"], - ca=certificate_dict["ca"], - chain=certificate_dict["chain"], + certificate=provider_certificate.certificate, + certificate_signing_request=provider_certificate.csr, + ca=provider_certificate.ca, + chain=provider_certificate.chain, ) - self.request_certificate_revocation(certificate_dict["certificate"].encode()) + self.request_certificate_revocation(provider_certificate.certificate.encode()) event.secret.remove_all_revisions() - def _find_certificate_in_relation_data(self, csr: str) -> Optional[Dict[str, Any]]: + def _find_certificate_in_relation_data(self, csr: str) -> Optional[ProviderCertificate]: """Returns the certificate that match the given CSR.""" - for certificate_dict in self._provider_certificates: - if certificate_dict["certificate_signing_request"] != csr: + for provider_certificate in self.get_provider_certificates(): + if provider_certificate.csr != csr: continue - return certificate_dict + return provider_certificate return None - - def _on_update_status(self, event: UpdateStatusEvent) -> None: - """Triggered on update status event. - - Goes through each certificate in the "certificates" relation and checks their expiry date. - If they are close to expire (<7 days), emits a CertificateExpiringEvent event and if - they are expired, emits a CertificateExpiredEvent. - - Args: - event (UpdateStatusEvent): Juju event - - Returns: - None - """ - for certificate_dict in self._provider_certificates: - expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) - if not expiry_time: - continue - time_difference = expiry_time - datetime.utcnow() - if time_difference.total_seconds() < 0: - logger.warning("Certificate is expired") - self.on.certificate_invalidated.emit( - reason="expired", - certificate=certificate_dict["certificate"], - certificate_signing_request=certificate_dict["certificate_signing_request"], - ca=certificate_dict["ca"], - chain=certificate_dict["chain"], - ) - self.request_certificate_revocation(certificate_dict["certificate"].encode()) - continue - if time_difference.total_seconds() < (self.expiry_notification_time * 60 * 60): - logger.warning("Certificate almost expired") - self.on.certificate_expiring.emit( - certificate=certificate_dict["certificate"], - expiry=expiry_time.isoformat(), - ) diff --git a/src/charm.py b/src/charm.py index 8548964..9568290 100755 --- a/src/charm.py +++ b/src/charm.py @@ -11,10 +11,10 @@ from charms.data_platform_libs.v0.data_interfaces import DatabaseRequires # type: ignore[import] from charms.sdcore_nrf_k8s.v0.fiveg_nrf import NRFRequires # type: ignore[import] -from charms.tls_certificates_interface.v2.tls_certificates import ( # type: ignore[import] +from charms.tls_certificates_interface.v3.tls_certificates import ( # type: ignore[import] CertificateAvailableEvent, CertificateExpiringEvent, - TLSCertificatesRequiresV2, + TLSCertificatesRequiresV3, generate_csr, generate_private_key, ) @@ -56,7 +56,7 @@ def __init__(self, *args): ) self._nrf_requires = NRFRequires(charm=self, relation_name=NRF_RELATION_NAME) self.unit.set_ports(PCF_SBI_PORT) - self._certificates = TLSCertificatesRequiresV2(self, "certificates") + self._certificates = TLSCertificatesRequiresV3(self, "certificates") self.framework.observe(self.on.database_relation_joined, self._configure_sdcore_pcf) self.framework.observe(self.on.database_relation_broken, self._on_database_relation_broken) self.framework.observe(self._database.on.database_created, self._configure_sdcore_pcf) diff --git a/tests/unit/test_charm.py b/tests/unit/test_charm.py index 0b171c5..9eeb6c4 100644 --- a/tests/unit/test_charm.py +++ b/tests/unit/test_charm.py @@ -507,7 +507,7 @@ def test_given_certificates_are_stored_when_on_certificates_relation_broken_then (root / "support/TLS/pcf.csr").read_text() @patch( - "charms.tls_certificates_interface.v2.tls_certificates.TLSCertificatesRequiresV2.request_certificate_creation", # noqa: E501 + "charms.tls_certificates_interface.v3.tls_certificates.TLSCertificatesRequiresV3.request_certificate_creation", # noqa: E501 new=Mock, ) @patch("charm.generate_csr") @@ -527,7 +527,7 @@ def test_given_private_key_exists_when_on_certificates_relation_joined_then_csr_ self.assertEqual((root / "support/TLS/pcf.csr").read_text(), csr.decode()) @patch( - "charms.tls_certificates_interface.v2.tls_certificates.TLSCertificatesRequiresV2.request_certificate_creation", # noqa: E501 + "charms.tls_certificates_interface.v3.tls_certificates.TLSCertificatesRequiresV3.request_certificate_creation", # noqa: E501 ) @patch("charm.generate_csr") def test_given_private_key_exists_and_cert_not_yet_requested_when_on_certificates_relation_joined_then_cert_is_requested( # noqa: E501 @@ -548,7 +548,7 @@ def test_given_private_key_exists_and_cert_not_yet_requested_when_on_certificate patch_request_certificate_creation.assert_called_with(certificate_signing_request=csr) @patch( - "charms.tls_certificates_interface.v2.tls_certificates.TLSCertificatesRequiresV2.request_certificate_creation", # noqa: E501 + "charms.tls_certificates_interface.v3.tls_certificates.TLSCertificatesRequiresV3.request_certificate_creation", # noqa: E501 ) def test_given_cert_already_stored_when_on_certificates_relation_joined_then_cert_is_not_requested( # noqa: E501 self, patch_request_certificate_creation @@ -605,7 +605,7 @@ def test_given_csr_doesnt_match_stored_one_when_certificate_available_then_certi (root / "support/TLS/pcf.pem").read_text() @patch( - "charms.tls_certificates_interface.v2.tls_certificates.TLSCertificatesRequiresV2.request_certificate_creation", # noqa: E501 + "charms.tls_certificates_interface.v3.tls_certificates.TLSCertificatesRequiresV3.request_certificate_creation", # noqa: E501 ) @patch("charm.generate_csr") def test_given_certificate_does_not_match_stored_one_when_certificate_expiring_then_certificate_is_not_requested( # noqa: E501 @@ -626,7 +626,7 @@ def test_given_certificate_does_not_match_stored_one_when_certificate_expiring_t patch_request_certificate_creation.assert_not_called() @patch( - "charms.tls_certificates_interface.v2.tls_certificates.TLSCertificatesRequiresV2.request_certificate_creation", # noqa: E501 + "charms.tls_certificates_interface.v3.tls_certificates.TLSCertificatesRequiresV3.request_certificate_creation", # noqa: E501 ) @patch("charm.generate_csr") def test_given_certificate_matches_stored_one_when_certificate_expiring_then_certificate_is_requested( # noqa: E501