diff --git a/lib/charms/sdcore_nms_k8s/v0/sdcore_config.py b/lib/charms/sdcore_nms_k8s/v0/sdcore_config.py index ebc7816..5411236 100644 --- a/lib/charms/sdcore_nms_k8s/v0/sdcore_config.py +++ b/lib/charms/sdcore_nms_k8s/v0/sdcore_config.py @@ -106,10 +106,10 @@ def _on_sdcore_config_relation_joined(self, event: RelationJoinedEvent): import logging from typing import Optional -from interface_tester.schema_base import DataBagSchema # type: ignore[import] +from interface_tester.schema_base import DataBagSchema from ops.charm import CharmBase, CharmEvents, RelationBrokenEvent, RelationChangedEvent from ops.framework import EventBase, EventSource, Handle, Object -from ops.model import Relation +from ops.model import ModelError, Relation from pydantic import BaseModel, Field, ValidationError # The unique Charmhub library identifier, never change it @@ -120,7 +120,7 @@ def _on_sdcore_config_relation_joined(self, event: RelationJoinedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 1 +LIBPATCH = 3 logger = logging.getLogger(__name__) @@ -150,7 +150,7 @@ class SdcoreConfigProviderAppData(BaseModel): class ProviderSchema(DataBagSchema): """The schema for the provider side of the sdcore-config interface.""" - app: SdcoreConfigProviderAppData + app_data: SdcoreConfigProviderAppData def data_is_valid(data: dict) -> bool: @@ -163,7 +163,7 @@ def data_is_valid(data: dict) -> bool: bool: True if data is valid, False otherwise. """ try: - ProviderSchema(app=data) + ProviderSchema(app_data=SdcoreConfigProviderAppData(**data)) return True except ValidationError as e: logger.error("Invalid data: %s", e) @@ -207,7 +207,7 @@ class SdcoreConfigRequirerCharmEvents(CharmEvents): class SdcoreConfigRequires(Object): """Class to be instantiated by the SD-Core config requirer charm.""" - on = SdcoreConfigRequirerCharmEvents() + on = SdcoreConfigRequirerCharmEvents() # type: ignore def __init__(self, charm: CharmBase, relation_name: str): """Init.""" @@ -336,4 +336,7 @@ def set_webui_url_in_all_relations(self, webui_url: str) -> None: raise RuntimeError(f"Relation {self.relation_name} not created yet.") for relation in relations: - relation.data[self.charm.app].update({"webui_url": webui_url}) + try: + relation.data[self.charm.app].update({"webui_url": webui_url}) + except ModelError as exc: + logger.error("Error updating the relation data: %s", str(exc)) diff --git a/lib/charms/tls_certificates_interface/v4/tls_certificates.py b/lib/charms/tls_certificates_interface/v4/tls_certificates.py index 10ca873..278207d 100644 --- a/lib/charms/tls_certificates_interface/v4/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v4/tls_certificates.py @@ -32,7 +32,7 @@ from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID -from ops import BoundEvent, CharmBase, CharmEvents, SecretExpiredEvent +from ops import BoundEvent, CharmBase, CharmEvents, SecretExpiredEvent, SecretRemoveEvent from ops.framework import EventBase, EventSource, Handle, Object from ops.jujuversion import JujuVersion from ops.model import ( @@ -52,7 +52,7 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 1 +LIBPATCH = 5 PYDEPS = ["cryptography", "pydantic"] @@ -305,6 +305,37 @@ def from_string(cls, certificate: str) -> "Certificate": validity_start_time=validity_start_time, ) + def matches_private_key(self, private_key: PrivateKey) -> bool: + """Check if this certificate matches a given private key. + + Args: + private_key (PrivateKey): The private key to validate against. + + Returns: + bool: True if the certificate matches the private key, False otherwise. + """ + try: + cert_object = x509.load_pem_x509_certificate(self.raw.encode()) + key_object = serialization.load_pem_private_key( + private_key.raw.encode(), password=None + ) + + cert_public_key = cert_object.public_key() + key_public_key = key_object.public_key() + + if not isinstance(cert_public_key, rsa.RSAPublicKey): + logger.warning("Certificate does not use RSA public key") + return False + + if not isinstance(key_public_key, rsa.RSAPublicKey): + logger.warning("Private key is not an RSA key") + return False + + return cert_public_key.public_numbers() == key_public_key.public_numbers() + except Exception as e: + logger.warning("Failed to validate certificate and private key match: %s", e) + return False + @dataclass(frozen=True) class CertificateSigningRequest: @@ -974,6 +1005,7 @@ def __init__( self.framework.observe(charm.on[relationship_name].relation_created, self._configure) self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) self.framework.observe(charm.on.secret_expired, self._on_secret_expired) + self.framework.observe(charm.on.secret_remove, self._on_secret_remove) for event in refresh_events: self.framework.observe(event, self._configure) @@ -993,9 +1025,20 @@ def _configure(self, _: EventBase): self._find_available_certificates() self._cleanup_certificate_requests() - def _mode_is_valid(self, mode) -> bool: + def _mode_is_valid(self, mode: Mode) -> bool: return mode in [Mode.UNIT, Mode.APP] + def _on_secret_remove(self, event: SecretRemoveEvent) -> None: + """Handle Secret Removed Event.""" + try: + event.secret.remove_revision(event.revision) + except SecretNotFoundError: + logger.warning( + "No such secret %s, nothing to remove", + event.secret.label or event.secret.id, + ) + return + def _on_secret_expired(self, event: SecretExpiredEvent) -> None: """Handle Secret Expired Event. @@ -1069,7 +1112,7 @@ def _get_app_or_unit(self) -> Union[Application, Unit]: raise TLSCertificatesError("Invalid mode") @property - def private_key(self) -> PrivateKey | None: + def private_key(self) -> Optional[PrivateKey]: """Return the private key.""" if not self._private_key_generated(): return None @@ -1238,7 +1281,7 @@ def _send_certificate_requests(self): def get_assigned_certificate( self, certificate_request: CertificateRequestAttributes - ) -> Tuple[ProviderCertificate | None, PrivateKey | None]: + ) -> Tuple[Optional[ProviderCertificate], Optional[PrivateKey]]: """Get the certificate that was assigned to the given certificate request.""" for requirer_csr in self.get_csrs_from_requirer_relation_data(): if certificate_request == CertificateRequestAttributes.from_csr( @@ -1248,7 +1291,9 @@ def get_assigned_certificate( return self._find_certificate_in_relation_data(requirer_csr), self.private_key return None, None - def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateKey | None]: + def get_assigned_certificates( + self, + ) -> Tuple[List[ProviderCertificate], Optional[PrivateKey]]: """Get a list of certificates that were assigned to this or app.""" assigned_certificates = [] for requirer_csr in self.get_csrs_from_requirer_relation_data(): @@ -1259,12 +1304,19 @@ def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateK def _find_certificate_in_relation_data( self, csr: RequirerCertificateRequest ) -> Optional[ProviderCertificate]: - """Return the certificate that match the given CSR.""" + """Return the certificate that matches the given CSR, validated against the private key.""" + if not self.private_key: + return None for provider_certificate in self.get_provider_certificates(): if ( provider_certificate.certificate_signing_request == csr.certificate_signing_request and provider_certificate.certificate.is_ca == csr.is_ca ): + if not provider_certificate.certificate.matches_private_key(self.private_key): + logger.warning( + "Certificate does not match the private key. Ignoring invalid certificate." + ) + continue return provider_certificate return None