Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Update charm libraries #48

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 156 additions & 72 deletions lib/charms/tls_certificates_interface/v2/tls_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,13 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 21
LIBPATCH = 22

PYDEPS = ["cryptography", "jsonschema"]

REQUIRER_JSON_SCHEMA = {
"$schema": "http://json-schema.org/draft-04/schema#",
"$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v2/schemas/requirer.json", # noqa: E501
"$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/requirer.json",
"type": "object",
"title": "`tls_certificates` requirer root schema",
"description": "The `tls_certificates` root schema comprises the entire requirer databag for this interface.", # noqa: E501
Expand Down Expand Up @@ -349,7 +349,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven

PROVIDER_JSON_SCHEMA = {
"$schema": "http://json-schema.org/draft-04/schema#",
"$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v2/schemas/provider.json", # noqa: E501
"$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/provider.json",
"type": "object",
"title": "`tls_certificates` provider root schema",
"description": "The `tls_certificates` root schema comprises the entire provider databag for this interface.", # noqa: E501
Expand Down Expand Up @@ -623,6 +623,40 @@ def _load_relation_data(relation_data_content: RelationDataContent) -> dict:
return certificate_data


def _get_closest_future_time(
expiry_notification_time: datetime, expiry_time: datetime
) -> datetime:
"""Return expiry_notification_time if not in the past, otherwise return expiry_time.

Args:
expiry_notification_time (datetime): Notification time of impending expiration
expiry_time (datetime): Expiration time

Returns:
datetime: expiry_notification_time if not in the past, expiry_time otherwise
"""
return (
expiry_notification_time if datetime.utcnow() < expiry_notification_time else expiry_time
)


def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]:
"""Extract expiry time from a certificate string.

Args:
certificate (str): x509 certificate as a string

Returns:
Optional[datetime]: Expiry datetime or None
"""
try:
certificate_object = x509.load_pem_x509_certificate(data=certificate.encode())
return certificate_object.not_valid_after
except ValueError:
logger.warning("Could not load certificate.")
return None


def generate_ca(
private_key: bytes,
subject: str,
Expand Down Expand Up @@ -984,6 +1018,38 @@ def generate_csr(
return signed_certificate.public_bytes(serialization.Encoding.PEM)


def csr_matches_certificate(csr: str, cert: str) -> bool:
"""Check if a CSR matches a certificate.

Args:
csr (str): Certificate Signing Request as a string
cert (str): Certificate as a string
Returns:
bool: True/False depending on whether the CSR matches the certificate.
"""
try:
csr_object = x509.load_pem_x509_csr(csr.encode("utf-8"))
cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8"))

if csr_object.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
) != cert_object.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
):
return False
if (
csr_object.public_key().public_numbers().n # type: ignore[union-attr]
!= cert_object.public_key().public_numbers().n # type: ignore[union-attr]
):
return False
except ValueError:
logger.warning("Could not load certificate or CSR.")
return False
return True


class CertificatesProviderCharmEvents(CharmEvents):
"""List of events that the TLS Certificates provider charm can leverage."""

Expand Down Expand Up @@ -1447,7 +1513,7 @@ def __init__(

@property
def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]:
"""Returns list of requirer's CSRs from relation data.
"""Returns list of requirer's CSRs from relation unit data.

Example:
[
Expand Down Expand Up @@ -1592,6 +1658,92 @@ def request_certificate_renewal(
)
logger.info("Certificate renewal request completed.")

def get_assigned_certificates(self) -> List[Dict[str, str]]:
"""Get a list of certificates that were assigned to this unit.

Returns:
List of certificates. For example:
[
{
"ca": "-----BEGIN CERTIFICATE-----...",
"chain": [
"-----BEGIN CERTIFICATE-----..."
],
"certificate": "-----BEGIN CERTIFICATE-----...",
"certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...",
}
]
"""
final_list = []
for csr in self.get_certificate_signing_requests(fulfilled_only=True):
assert type(csr["certificate_signing_request"]) == str
if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]):
final_list.append(cert)
return final_list

def get_expiring_certificates(self) -> List[Dict[str, str]]:
"""Get a list of certificates that were assigned to this unit that are expiring or expired.

Returns:
List of certificates. For example:
[
{
"ca": "-----BEGIN CERTIFICATE-----...",
"chain": [
"-----BEGIN CERTIFICATE-----..."
],
"certificate": "-----BEGIN CERTIFICATE-----...",
"certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...",
}
]
"""
final_list = []
for csr in self.get_certificate_signing_requests(fulfilled_only=True):
assert type(csr["certificate_signing_request"]) == str
if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]):
expiry_time = _get_certificate_expiry_time(cert["certificate"])
if not expiry_time:
continue
expiry_notification_time = expiry_time - timedelta(
hours=self.expiry_notification_time
)
if datetime.utcnow() > expiry_notification_time:
final_list.append(cert)
return final_list

def get_certificate_signing_requests(
self,
fulfilled_only: bool = False,
unfulfilled_only: bool = False,
) -> List[Dict[str, Union[bool, str]]]:
"""Gets the list of CSR's that were sent to the provider.

You can choose to get only the CSR's that have a certificate assigned or only the CSR's
that don't.

Args:
fulfilled_only (bool): This option will discard CSRs that don't have certificates yet.
unfulfilled_only (bool): This option will discard CSRs that have certificates signed.
Returns:
List of CSR dictionaries. For example:
[
{
"certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...",
"ca": false
}
]
"""

final_list = []
for csr in self._requirer_csrs:
assert type(csr["certificate_signing_request"]) == str
cert = self._find_certificate_in_relation_data(csr["certificate_signing_request"])
if (unfulfilled_only and cert) or (fulfilled_only and not cert):
continue
final_list.append(csr)

return final_list

@staticmethod
def _relation_data_is_valid(certificates_data: dict) -> bool:
"""Checks whether relation data is valid based on json schema.
Expand Down Expand Up @@ -1802,71 +1954,3 @@ def _on_update_status(self, event: UpdateStatusEvent) -> None:
certificate=certificate_dict["certificate"],
expiry=expiry_time.isoformat(),
)


def csr_matches_certificate(csr: str, cert: str) -> bool:
"""Check if a CSR matches a certificate.

expects to get the original string representations.

Args:
csr (str): Certificate Signing Request
cert (str): Certificate
Returns:
bool: True/False depending on whether the CSR matches the certificate.
"""
try:
csr_object = x509.load_pem_x509_csr(csr.encode("utf-8"))
cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8"))

if csr_object.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
) != cert_object.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
):
return False
if (
csr_object.public_key().public_numbers().n # type: ignore[union-attr]
!= cert_object.public_key().public_numbers().n # type: ignore[union-attr]
):
return False
except ValueError:
logger.warning("Could not load certificate or CSR.")
return False
return True


def _get_closest_future_time(
expiry_notification_time: datetime, expiry_time: datetime
) -> datetime:
"""Return expiry_notification_time if not in the past, otherwise return expiry_time.

Args:
expiry_notification_time (datetime): Notification time of impending expiration
expiry_time (datetime): Expiration time

Returns:
datetime: expiry_notification_time if not in the past, expiry_time otherwise
"""
return (
expiry_notification_time if datetime.utcnow() < expiry_notification_time else expiry_time
)


def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]:
"""Extract expiry time from a certificate string.

Args:
certificate (str): x509 certificate as a string

Returns:
Optional[datetime]: Expiry datetime or None
"""
try:
certificate_object = x509.load_pem_x509_certificate(data=certificate.encode())
return certificate_object.not_valid_after
except ValueError:
logger.warning("Could not load certificate.")
return None
Loading