From 88ec2ad05811b4f7f8de9586e1096a685e8ef72d Mon Sep 17 00:00:00 2001 From: saltiyazan Date: Fri, 11 Oct 2024 18:09:38 +0400 Subject: [PATCH] chore: Use tls lib V4.0 (#339) --- .../v4/tls_certificates.py | 479 ++++++------------ src/charm.py | 6 +- tests/unit/certificates_helpers.py | 6 +- 3 files changed, 176 insertions(+), 315 deletions(-) diff --git a/lib/charms/tls_certificates_interface/v4/tls_certificates.py b/lib/charms/tls_certificates_interface/v4/tls_certificates.py index a46b909..0de221e 100644 --- a/lib/charms/tls_certificates_interface/v4/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v4/tls_certificates.py @@ -1,152 +1,19 @@ # Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. -"""Charm library for managing TLS certificates (V4) - BETA. - -> Warning: This is a beta version of the tls-certificates interface library. -> Use at your own risk. +"""Charm library for managing TLS certificates (V4). This library contains the Requires and Provides classes for handling the tls-certificates interface. Pre-requisites: - Juju >= 3.0 + - cryptography >= 43.0.0 + - pydantic -## Getting Started -From a charm directory, fetch the library using `charmcraft`: - -```shell -charmcraft fetch-lib charms.tls_certificates_interface.v4.tls_certificates -``` - -Add the following libraries to the charm's `requirements.txt` file: -- cryptography >= 42.0.0 -- pydantic >= 2.0.0 - -Add the following section to the charm's `charmcraft.yaml` file: -```yaml -parts: - charm: - build-packages: - - libffi-dev - - libssl-dev - - rustc - - cargo -``` - -### Requirer charm -The requirer charm is the charm requiring certificates from another charm that provides them. - -#### Example - -In the following example, the requiring charm requests a certificate using attributes -from the Juju configuration options. - -```python -from typing import List, Optional, cast - -from ops.charm import ActionEvent, CharmBase -from ops.main import main - -from lib.charms.tls_certificates_interface.v4.tls_certificates import ( - CertificateAvailableEvent, - CertificateRequest, - Mode, - TLSCertificatesRequiresV4, -) - - -class DummyTLSCertificatesRequirerCharm(CharmBase): - def __init__(self, *args): - super().__init__(*args) - certificate_requests = self._get_certificate_requests() - self.certificates = TLSCertificatesRequiresV4( - charm=self, - relationship_name="certificates", - certificate_requests=certificate_requests, - mode=Mode.UNIT, - refresh_events=[self.on.config_changed], - ) - self.framework.observe( - self.certificates.on.certificate_available, self._on_certificate_available - ) - self.framework.observe( - self.on.regenerate_private_key_action, self._on_regenerate_private_key_action - ) - self.framework.observe(self.on.get_certificate_action, self._on_get_certificate_action) - - def _get_certificate_requests(self) -> List[CertificateRequest]: - if not self._get_config_common_name(): - return [] - return [ - CertificateRequest( - common_name=self._get_config_common_name(), - sans_dns=self._get_config_sans_dns(), - organization=self._get_config_organization_name(), - organizational_unit=self._get_config_organization_unit_name(), - email_address=self._get_config_email_address(), - country_name=self._get_config_country_name(), - state_or_province_name=self._get_config_state_or_province_name(), - locality_name=self._get_config_locality_name(), - ) - ] - - def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: - print("Certificate available") - - def _on_regenerate_private_key_action(self, event: ActionEvent) -> None: - self.certificates.regenerate_private_key() - - def _on_get_certificate_action(self, event: ActionEvent) -> None: - certificate, _ = self.certificates.get_assigned_certificate( - certificate_request=self._get_certificate_requests()[0] - ) - if not certificate: - event.fail("Certificate not available") - return - event.set_results( - { - "certificate": str(certificate.certificate), - "ca": str(certificate.ca), - "csr": str(certificate.certificate_signing_request), - } - ) - - def _get_config_common_name(self) -> str: - return cast(str, self.model.config.get("common_name")) - - def _get_config_sans_dns(self) -> List[str]: - config_sans_dns = cast(str, self.model.config.get("sans_dns", "")) - return config_sans_dns.split(",") if config_sans_dns else [] - - def _get_config_organization_name(self) -> Optional[str]: - return cast(str, self.model.config.get("organization_name")) +Learn more on how-to use the TLS Certificates interface library by reading the documentation: +- https://charmhub.io/tls-certificates-interface/ - def _get_config_organization_unit_name(self) -> Optional[str]: - return cast(str, self.model.config.get("organization_unit_name")) - - def _get_config_email_address(self) -> Optional[str]: - return cast(str, self.model.config.get("email_address")) - - def _get_config_country_name(self) -> Optional[str]: - return cast(str, self.model.config.get("country_name")) - - def _get_config_state_or_province_name(self) -> Optional[str]: - return cast(str, self.model.config.get("state_or_province_name")) - - def _get_config_locality_name(self) -> Optional[str]: - return cast(str, self.model.config.get("locality_name")) - - -if __name__ == "__main__": - main(DummyTLSCertificatesRequirerCharm) -``` - -You can integrate both charms by running: - -```bash -juju integrate -``` """ # noqa: D214, D405, D411, D416 import copy @@ -185,7 +52,7 @@ def _get_config_locality_name(self) -> Optional[str]: # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 5 +LIBPATCH = 0 PYDEPS = ["cryptography", "pydantic"] @@ -276,7 +143,6 @@ class _Certificate(BaseModel): certificate_signing_request: str certificate: str chain: Optional[List[str]] = None - recommended_expiry_notification_time: Optional[int] = None revoked: Optional[bool] = None def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate": @@ -291,7 +157,6 @@ def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate": chain=[Certificate.from_string(certificate) for certificate in self.chain] if self.chain else [], - recommended_expiry_notification_time=self.recommended_expiry_notification_time, revoked=self.revoked, ) @@ -353,17 +218,18 @@ class Certificate: raw: str common_name: str - sans_dns: Optional[FrozenSet[str]] = None - sans_ip: Optional[FrozenSet[str]] = None - sans_oid: Optional[FrozenSet[str]] = None + expiry_time: datetime + validity_start_time: datetime + is_ca: bool = False + sans_dns: Optional[FrozenSet[str]] = frozenset() + sans_ip: Optional[FrozenSet[str]] = frozenset() + sans_oid: Optional[FrozenSet[str]] = frozenset() email_address: Optional[str] = None organization: Optional[str] = None organizational_unit: Optional[str] = None country_name: Optional[str] = None state_or_province_name: Optional[str] = None locality_name: Optional[str] = None - expiry_time: Optional[datetime] = None - validity_start_time: Optional[datetime] = None def __str__(self) -> str: """Return the certificate as a string.""" @@ -412,9 +278,18 @@ def from_string(cls, certificate: str) -> "Certificate": sans_oid = [] expiry_time = certificate_object.not_valid_after_utc validity_start_time = certificate_object.not_valid_before_utc + is_ca = False + try: + is_ca = certificate_object.extensions.get_extension_for_oid( + ExtensionOID.BASIC_CONSTRAINTS + ).value.ca # type: ignore[reportAttributeAccessIssue] + except x509.ExtensionNotFound: + pass + return cls( raw=certificate.strip(), common_name=str(common_name[0].value), + is_ca=is_ca, country_name=str(country_name[0].value) if country_name else None, state_or_province_name=str(state_or_province_name[0].value) if state_or_province_name @@ -446,7 +321,6 @@ class CertificateSigningRequest: country_name: Optional[str] = None state_or_province_name: Optional[str] = None locality_name: Optional[str] = None - is_ca: bool = False def __eq__(self, other: object) -> bool: """Check if two CertificateSigningRequest objects are equal.""" @@ -458,22 +332,6 @@ def __str__(self) -> str: """Return the CSR as a string.""" return self.raw - def to_certificate_request(self) -> "CertificateRequest": - """Convert to a CertificateRequest object.""" - return CertificateRequest( - common_name=self.common_name, - sans_dns=self.sans_dns, - sans_ip=self.sans_ip, - sans_oid=self.sans_oid, - email_address=self.email_address, - organization=self.organization, - organizational_unit=self.organizational_unit, - country_name=self.country_name, - state_or_province_name=self.state_or_province_name, - locality_name=self.locality_name, - is_ca=self.is_ca, - ) - @classmethod def from_string(cls, csr: str) -> "CertificateSigningRequest": """Create a CertificateSigningRequest object from a CSR.""" @@ -511,8 +369,8 @@ def from_string(cls, csr: str) -> "CertificateSigningRequest": organization=str(organization_name[0].value) if organization_name else None, email_address=str(email_address[0].value) if email_address else None, sans_dns=sans_dns, - sans_ip=sans_ip if sans_ip else None, - sans_oid=sans_oid if sans_oid else None, + sans_ip=sans_ip, + sans_oid=sans_oid, ) def matches_private_key(self, key: PrivateKey) -> bool: @@ -569,17 +427,17 @@ def get_sha256_hex(self) -> str: @dataclass(frozen=True) -class CertificateRequest: - """This class represents a certificate request. +class CertificateRequestAttributes: + """A representation of the certificate request attributes. This class should be used inside the requirer charm to specify the requested attributes for the certificate. """ common_name: str - sans_dns: Optional[FrozenSet[str]] = None - sans_ip: Optional[FrozenSet[str]] = None - sans_oid: Optional[FrozenSet[str]] = None + sans_dns: Optional[FrozenSet[str]] = frozenset() + sans_ip: Optional[FrozenSet[str]] = frozenset() + sans_oid: Optional[FrozenSet[str]] = frozenset() email_address: Optional[str] = None organization: Optional[str] = None organizational_unit: Optional[str] = None @@ -620,6 +478,23 @@ def generate_csr( locality_name=self.locality_name, ) + @classmethod + def from_csr(cls, csr: CertificateSigningRequest, is_ca: bool): + """Create a CertificateRequestAttributes object from a CSR.""" + return cls( + common_name=csr.common_name, + sans_dns=csr.sans_dns, + sans_ip=csr.sans_ip, + sans_oid=csr.sans_oid, + email_address=csr.email_address, + organization=csr.organization, + organizational_unit=csr.organizational_unit, + country_name=csr.country_name, + state_or_province_name=csr.state_or_province_name, + locality_name=csr.locality_name, + is_ca=is_ca, + ) + @dataclass(frozen=True) class ProviderCertificate: @@ -630,7 +505,6 @@ class ProviderCertificate: certificate_signing_request: CertificateSigningRequest ca: Certificate chain: List[Certificate] - recommended_expiry_notification_time: Optional[int] = None revoked: Optional[bool] = None def to_json(self) -> str: @@ -651,11 +525,12 @@ def to_json(self) -> str: @dataclass(frozen=True) -class RequirerCSR: - """This class represents a certificate signing request requested by the TLS requirer.""" +class RequirerCertificateRequest: + """This class represents a certificate signing request requested by a specific TLS requirer.""" relation_id: int certificate_signing_request: CertificateSigningRequest + is_ca: bool class CertificateAvailableEvent(EventBase): @@ -699,57 +574,6 @@ def chain_as_pem(self) -> str: return "\n\n".join([str(cert) for cert in self.chain]) -def _get_closest_future_time( - expiry_notification_time: datetime, expiry_time: datetime -) -> datetime: - """Return expiry_notification_time if not in the past, otherwise return expiry_time. - - Args: - expiry_notification_time (datetime): Notification time of impending expiration - expiry_time (datetime): Expiration time - - Returns: - datetime: expiry_notification_time if not in the past, expiry_time otherwise - """ - return ( - expiry_notification_time - if datetime.now(timezone.utc) < expiry_notification_time - else expiry_time - ) - - -def calculate_expiry_notification_time( - validity_start_time: datetime, - expiry_time: datetime, - provider_recommended_notification_time: Optional[int], -) -> datetime: - """Calculate a reasonable time to notify the user about the certificate expiry. - - It takes into account the time recommended by the provider. - Time recommended by the provider is preferred, - then dynamically calculated time. - - Args: - validity_start_time: Certificate validity time - expiry_time: Certificate expiry time - provider_recommended_notification_time: - Time in hours prior to expiry to notify the user. - Recommended by the provider. - - Returns: - datetime: Time to notify the user about the certificate expiry. - """ - if provider_recommended_notification_time is not None: - provider_recommended_notification_time = abs(provider_recommended_notification_time) - provider_recommendation_time_delta = expiry_time - timedelta( - hours=provider_recommended_notification_time - ) - if validity_start_time < provider_recommendation_time_delta: - return provider_recommendation_time_delta - calculated_hours = (expiry_time - validity_start_time).total_seconds() / (3600 * 3) - return expiry_time - timedelta(hours=calculated_hours) - - def generate_private_key( key_size: int = 2048, public_exponent: int = 65537, @@ -778,9 +602,9 @@ def generate_private_key( def generate_csr( # noqa: C901 private_key: PrivateKey, common_name: str, - sans_dns: Optional[FrozenSet[str]] = None, - sans_ip: Optional[FrozenSet[str]] = None, - sans_oid: Optional[FrozenSet[str]] = None, + sans_dns: Optional[FrozenSet[str]] = frozenset(), + sans_ip: Optional[FrozenSet[str]] = frozenset(), + sans_oid: Optional[FrozenSet[str]] = frozenset(), organization: Optional[str] = None, organizational_unit: Optional[str] = None, email_address: Optional[str] = None, @@ -851,11 +675,11 @@ def generate_csr( # noqa: C901 def generate_ca( private_key: PrivateKey, - validity: int, + validity: timedelta, common_name: str, - sans_dns: Optional[FrozenSet[str]] = None, - sans_ip: Optional[FrozenSet[str]] = None, - sans_oid: Optional[FrozenSet[str]] = None, + sans_dns: Optional[FrozenSet[str]] = frozenset(), + sans_ip: Optional[FrozenSet[str]] = frozenset(), + sans_oid: Optional[FrozenSet[str]] = frozenset(), organization: Optional[str] = None, organizational_unit: Optional[str] = None, email_address: Optional[str] = None, @@ -867,7 +691,7 @@ def generate_ca( Args: private_key (PrivateKey): Private key - validity (int): Certificate validity time (in days) + validity (timedelta): Certificate validity time common_name (str): Common Name that can be an IP or a Full Qualified Domain Name (FQDN). sans_dns (FrozenSet[str]): DNS Subject Alternative Names sans_ip (FrozenSet[str]): IP Subject Alternative Names @@ -934,7 +758,7 @@ def generate_ca( .public_key(private_key_object.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(datetime.now(timezone.utc)) - .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) + .not_valid_after(datetime.now(timezone.utc) + validity) .add_extension(x509.SubjectAlternativeName(set(_sans)), critical=False) .add_extension(x509.SubjectKeyIdentifier(digest=subject_identifier), critical=False) .add_extension( @@ -960,7 +784,7 @@ def generate_certificate( csr: CertificateSigningRequest, ca: Certificate, ca_private_key: PrivateKey, - validity: int, + validity: timedelta, is_ca: bool = False, ) -> Certificate: """Generate a TLS certificate based on a CSR. @@ -969,7 +793,7 @@ def generate_certificate( csr (CertificateSigningRequest): CSR ca (Certificate): CA Certificate ca_private_key (PrivateKey): CA private key - validity (int): Certificate validity (in days) + validity (timedelta): Certificate validity time is_ca (bool): Whether the certificate is a CA certificate Returns: @@ -988,7 +812,7 @@ def generate_certificate( .public_key(csr_object.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(datetime.now(timezone.utc)) - .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) + .not_valid_after(datetime.now(timezone.utc) + validity) ) extensions = _get_certificate_request_extensions( authority_key_identifier=ca_pem.extensions.get_extension_for_class( @@ -1120,7 +944,7 @@ def __init__( self, charm: CharmBase, relationship_name: str, - certificate_requests: List[CertificateRequest], + certificate_requests: List[CertificateRequestAttributes], mode: Mode = Mode.UNIT, refresh_events: List[BoundEvent] = [], ): @@ -1129,7 +953,8 @@ def __init__( Args: charm (CharmBase): The charm instance to relate to. relationship_name (str): The name of the relation that provides the certificates. - certificate_requests (List[CertificateRequest]): A list of certificate requests. + certificate_requests (List[CertificateRequestAttributes]): + A list with the attributes of the certificate requests. mode (Mode): Whether to use unit or app certificates mode. Default is Mode.UNIT. refresh_events (List[BoundEvent]): A list of events to trigger a refresh of the certificates. @@ -1181,12 +1006,28 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: try: csr_str = event.secret.get_content(refresh=True)["csr"] except ModelError: - logger.error("Failed to get CSR from secret - Skipping renewal") + logger.error("Failed to get CSR from secret - Skipping") return csr = CertificateSigningRequest.from_string(csr_str) self._renew_certificate_request(csr) event.secret.remove_all_revisions() + def renew_certificate(self, certificate: ProviderCertificate) -> None: + """Request the renewal of the provided certificate.""" + certificate_signing_request = certificate.certificate_signing_request + secret_label = self._get_csr_secret_label(certificate_signing_request) + try: + secret = self.model.get_secret(label=secret_label) + except SecretNotFoundError: + logger.warning("No matching secret found - Skipping renewal") + return + current_csr = secret.get_content(refresh=True).get("csr", "") + if current_csr != str(certificate_signing_request): + logger.warning("No matching CSR found - Skipping renewal") + return + self._renew_certificate_request(certificate_signing_request) + secret.remove_all_revisions() + def _renew_certificate_request(self, csr: CertificateSigningRequest): """Remove existing CSR from relation data and create a new one.""" self._remove_requirer_csr_from_relation_data(csr) @@ -1269,31 +1110,40 @@ def _private_key_generated(self) -> bool: return False return True - def _csr_matches_certificate_request(self, csr: CertificateSigningRequest) -> bool: + def _csr_matches_certificate_request( + self, certificate_signing_request: CertificateSigningRequest, is_ca: bool + ) -> bool: for certificate_request in self.certificate_requests: - if csr.to_certificate_request() == certificate_request: + if certificate_request == CertificateRequestAttributes.from_csr( + certificate_signing_request, + is_ca, + ): return True return False - def _certificate_requested(self, certificate_request: CertificateRequest) -> bool: + def _certificate_requested(self, certificate_request: CertificateRequestAttributes) -> bool: if not self.private_key: return False csr = self._certificate_requested_for_attributes(certificate_request) if not csr: return False - if not csr.matches_private_key(key=self.private_key): + if not csr.certificate_signing_request.matches_private_key(key=self.private_key): return False return True def _certificate_requested_for_attributes( - self, certificate_request: CertificateRequest - ) -> Optional[CertificateSigningRequest]: + self, + certificate_request: CertificateRequestAttributes, + ) -> Optional[RequirerCertificateRequest]: for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if requirer_csr.to_certificate_request() == certificate_request: + if certificate_request == CertificateRequestAttributes.from_csr( + requirer_csr.certificate_signing_request, + requirer_csr.is_ca, + ): return requirer_csr return None - def get_csrs_from_requirer_relation_data(self) -> List[CertificateSigningRequest]: + def get_csrs_from_requirer_relation_data(self) -> List[RequirerCertificateRequest]: """Return list of requirer's CSRs from relation data.""" if self.mode == Mode.APP and not self.model.unit.is_leader(): logger.debug("Not a leader unit - Skipping") @@ -1308,10 +1158,18 @@ def get_csrs_from_requirer_relation_data(self) -> List[CertificateSigningRequest except DataValidationError: logger.warning("Invalid relation data") return [] - return [ - CertificateSigningRequest.from_string(csr.certificate_signing_request) - for csr in requirer_relation_data.certificate_signing_requests - ] + requirer_csrs = [] + for csr in requirer_relation_data.certificate_signing_requests: + requirer_csrs.append( + RequirerCertificateRequest( + relation_id=relation.id, + certificate_signing_request=CertificateSigningRequest.from_string( + csr.certificate_signing_request + ), + is_ca=csr.ca if csr.ca else False, + ) + ) + return requirer_csrs def get_provider_certificates(self) -> List[ProviderCertificate]: """Return list of certificates from the provider's relation data.""" @@ -1379,11 +1237,14 @@ def _send_certificate_requests(self): self._request_certificate(csr=csr, is_ca=certificate_request.is_ca) def get_assigned_certificate( - self, certificate_request: CertificateRequest + self, certificate_request: CertificateRequestAttributes ) -> Tuple[ProviderCertificate | None, PrivateKey | None]: """Get the certificate that was assigned to the given certificate request.""" for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if certificate_request == requirer_csr.to_certificate_request(): + if certificate_request == CertificateRequestAttributes.from_csr( + requirer_csr.certificate_signing_request, + requirer_csr.is_ca, + ): return self._find_certificate_in_relation_data(requirer_csr), self.private_key return None, None @@ -1396,11 +1257,14 @@ def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateK return assigned_certificates, self.private_key def _find_certificate_in_relation_data( - self, csr: CertificateSigningRequest + self, csr: RequirerCertificateRequest ) -> Optional[ProviderCertificate]: """Return the certificate that match the given CSR.""" for provider_certificate in self.get_provider_certificates(): - if provider_certificate.certificate_signing_request == csr: + if ( + provider_certificate.certificate_signing_request == csr.certificate_signing_request + and provider_certificate.certificate.is_ca == csr.is_ca + ): return provider_certificate return None @@ -1412,9 +1276,10 @@ def _find_available_certificates(self): If a certificate is revoked, the secret will be removed and an event will be emitted. """ requirer_csrs = self.get_csrs_from_requirer_relation_data() + csrs = [csr.certificate_signing_request for csr in requirer_csrs] provider_certificates = self.get_provider_certificates() for provider_certificate in provider_certificates: - if provider_certificate.certificate_signing_request in requirer_csrs: + if provider_certificate.certificate_signing_request in csrs: secret_label = self._get_csr_secret_label( provider_certificate.certificate_signing_request ) @@ -1428,7 +1293,8 @@ def _find_available_certificates(self): secret.remove_all_revisions() else: if not self._csr_matches_certificate_request( - provider_certificate.certificate_signing_request + certificate_signing_request=provider_certificate.certificate_signing_request, + is_ca=provider_certificate.certificate.is_ca, ): logger.debug("Certificate requested for different attributes - Skipping") continue @@ -1442,7 +1308,7 @@ def _find_available_certificates(self): } ) secret.set_info( - expire=self._get_next_secret_expiry_time(provider_certificate), + expire=provider_certificate.certificate.expiry_time, ) except SecretNotFoundError: logger.debug("Creating new secret with label %s", secret_label) @@ -1452,7 +1318,7 @@ def _find_available_certificates(self): "csr": str(provider_certificate.certificate_signing_request), }, label=secret_label, - expire=self._get_next_secret_expiry_time(provider_certificate), + expire=provider_certificate.certificate.expiry_time, ) self.on.certificate_available.emit( certificate_signing_request=provider_certificate.certificate_signing_request, @@ -1470,54 +1336,29 @@ def _cleanup_certificate_requests(self): - The CSR public key does not match the private key. """ for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if not self._csr_matches_certificate_request(requirer_csr): - self._remove_requirer_csr_from_relation_data(requirer_csr) + if not self._csr_matches_certificate_request( + certificate_signing_request=requirer_csr.certificate_signing_request, + is_ca=requirer_csr.is_ca, + ): + self._remove_requirer_csr_from_relation_data( + requirer_csr.certificate_signing_request + ) logger.info( - "Removed CSR from relation data because \ - it did not match any certificate request" + "Removed CSR from relation data because it did not match any certificate request" # noqa: E501 + ) + elif ( + self.private_key + and not requirer_csr.certificate_signing_request.matches_private_key( + self.private_key + ) + ): + self._remove_requirer_csr_from_relation_data( + requirer_csr.certificate_signing_request ) - elif self.private_key and not requirer_csr.matches_private_key(self.private_key): - self._remove_requirer_csr_from_relation_data(requirer_csr) logger.info( - "Removed CSR from relation data because \ - it did not match the private key" + "Removed CSR from relation data because it did not match the private key" # noqa: E501 ) - def _get_next_secret_expiry_time( - self, provider_certificate: ProviderCertificate - ) -> Optional[datetime]: - """Return the expiry time or expiry notification time. - - Extracts the expiry time from the provided certificate, calculates the - expiry notification time and return the closest of the two, that is in - the future. - - Args: - provider_certificate: ProviderCertificate object - - Returns: - Optional[datetime]: None if the certificate expiry time cannot be read, - next expiry time otherwise. - """ - if not provider_certificate.certificate.expiry_time: - logger.warning("Certificate has no expiry time") - return None - if not provider_certificate.certificate.validity_start_time: - logger.warning("Certificate has no validity start time") - return None - expiry_notification_time = calculate_expiry_notification_time( - validity_start_time=provider_certificate.certificate.validity_start_time, - expiry_time=provider_certificate.certificate.expiry_time, - provider_recommended_notification_time=provider_certificate.recommended_expiry_notification_time, - ) - if not expiry_notification_time: - logger.warning("Could not calculate expiry notification time") - return None - return _get_closest_future_time( - expiry_notification_time, - provider_certificate.certificate.expiry_time, - ) - def _tls_relation_created(self) -> bool: relation = self.model.get_relation(self.relationship_name) if not relation: @@ -1593,10 +1434,12 @@ def _get_tls_relations(self, relation_id: Optional[int] = None) -> List[Relation else self.model.relations.get(self.relationship_name, []) ) - def get_certificate_requests(self, relation_id: Optional[int] = None) -> List[RequirerCSR]: + def get_certificate_requests( + self, relation_id: Optional[int] = None + ) -> List[RequirerCertificateRequest]: """Load certificate requests from the relation data.""" relations = self._get_tls_relations(relation_id) - requirer_csrs: List[RequirerCSR] = [] + requirer_csrs: List[RequirerCertificateRequest] = [] for relation in relations: for unit in relation.units: requirer_csrs.extend(self._load_requirer_databag(relation, unit)) @@ -1605,18 +1448,19 @@ def get_certificate_requests(self, relation_id: Optional[int] = None) -> List[Re def _load_requirer_databag( self, relation: Relation, unit_or_app: Union[Application, Unit] - ) -> List[RequirerCSR]: + ) -> List[RequirerCertificateRequest]: try: requirer_relation_data = _RequirerData.load(relation.data[unit_or_app]) except DataValidationError: logger.debug("Invalid requirer relation data for %s", unit_or_app.name) return [] return [ - RequirerCSR( + RequirerCertificateRequest( relation_id=relation.id, certificate_signing_request=CertificateSigningRequest.from_string( csr.certificate_signing_request ), + is_ca=csr.ca if csr.ca else False, ) for csr in requirer_relation_data.certificate_signing_requests ] @@ -1631,7 +1475,6 @@ def _add_provider_certificate( certificate_signing_request=str(provider_certificate.certificate_signing_request), ca=str(provider_certificate.ca), chain=[str(certificate) for certificate in provider_certificate.chain], - recommended_expiry_notification_time=provider_certificate.recommended_expiry_notification_time, ) provider_certificates = self._load_provider_certificates(relation) if new_certificate in provider_certificates: @@ -1746,19 +1589,35 @@ def get_provider_certificates( certificates.append(certificate.to_provider_certificate(relation_id=relation.id)) return certificates + def get_unsolicited_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return provider certificates for which no certificate requests exists. + + Those certificates should be revoked. + """ + unsolicited_certificates: List[ProviderCertificate] = [] + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + requirer_csrs = self.get_certificate_requests(relation_id=relation_id) + list_of_csrs = [csr.certificate_signing_request for csr in requirer_csrs] + for certificate in provider_certificates: + if certificate.certificate_signing_request not in list_of_csrs: + unsolicited_certificates.append(certificate) + return unsolicited_certificates + def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None - ) -> List[RequirerCSR]: + ) -> List[RequirerCertificateRequest]: """Return CSR's for which no certificate has been issued. Args: relation_id (int): Relation id Returns: - list: List of RequirerCSR objects. + list: List of RequirerCertificateRequest objects. """ requirer_csrs = self.get_certificate_requests(relation_id=relation_id) - outstanding_csrs: List[RequirerCSR] = [] + outstanding_csrs: List[RequirerCertificateRequest] = [] for relation_csr in requirer_csrs: if not self._certificate_issued_for_csr( csr=relation_csr.certificate_signing_request, diff --git a/src/charm.py b/src/charm.py index f950b69..3eaf90d 100755 --- a/src/charm.py +++ b/src/charm.py @@ -19,7 +19,7 @@ from charms.sdcore_nrf_k8s.v0.fiveg_nrf import NRFRequires from charms.tls_certificates_interface.v4.tls_certificates import ( Certificate, - CertificateRequest, + CertificateRequestAttributes, Mode, PrivateKey, TLSCertificatesRequiresV4, @@ -296,8 +296,8 @@ def _certificate_is_available(self) -> bool: ) return bool(cert and key) - def _get_certificate_request(self) -> CertificateRequest: - return CertificateRequest( + def _get_certificate_request(self) -> CertificateRequestAttributes: + return CertificateRequestAttributes( common_name=CERTIFICATE_COMMON_NAME, sans_dns=frozenset([CERTIFICATE_COMMON_NAME]), ) diff --git a/tests/unit/certificates_helpers.py b/tests/unit/certificates_helpers.py index d3161fe..bc4b5a4 100644 --- a/tests/unit/certificates_helpers.py +++ b/tests/unit/certificates_helpers.py @@ -1,6 +1,8 @@ # Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. +from datetime import timedelta + from charms.tls_certificates_interface.v4.tls_certificates import ( PrivateKey, ProviderCertificate, @@ -21,13 +23,13 @@ def example_cert_and_key(relation_id: int) -> tuple[ProviderCertificate, Private ca_certificate = generate_ca( private_key=ca_private_key, common_name="ca.com", - validity=365, + validity=timedelta(days=365), ) certificate = generate_certificate( csr=csr, ca=ca_certificate, ca_private_key=ca_private_key, - validity=365, + validity=timedelta(days=365), ) provider_certificate = ProviderCertificate( relation_id=relation_id,