From 5e22fbc6663c63e219c1ab5f6cea7cb8fdea84b3 Mon Sep 17 00:00:00 2001 From: James Garner Date: Wed, 29 Jan 2025 11:46:15 +1300 Subject: [PATCH 1/4] style: type annotation pass + some light refactoring to simplify types --- lib/charms/operator_libs_linux/v0/apt.py | 166 ++++++++++++++--------- tox.ini | 1 + 2 files changed, 100 insertions(+), 67 deletions(-) diff --git a/lib/charms/operator_libs_linux/v0/apt.py b/lib/charms/operator_libs_linux/v0/apt.py index 2e463de8..c2313864 100644 --- a/lib/charms/operator_libs_linux/v0/apt.py +++ b/lib/charms/operator_libs_linux/v0/apt.py @@ -108,9 +108,10 @@ import os import re import subprocess +import typing from enum import Enum from subprocess import PIPE, CalledProcessError, check_output -from typing import Any, Iterable, Iterator, Mapping +from typing import Any, Iterable, Iterator, Literal, Mapping from urllib.parse import urlparse logger = logging.getLogger(__name__) @@ -197,7 +198,7 @@ def __init__( self._state = state self._version = Version(version, epoch) - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: """Equality for comparison. Args: @@ -232,7 +233,7 @@ def __str__(self): @staticmethod def _apt( command: str, - package_names: str | list, + package_names: str | list[str], optargs: list[str] | None = None, ) -> None: """Wrap package management commands for Debian/Ubuntu systems. @@ -346,8 +347,10 @@ def fullversion(self) -> str: def _get_epoch_from_version(version: str) -> tuple[str, str]: """Pull the epoch, if any, out of a version string.""" epoch_matcher = re.compile(r"^((?P\d+):)?(?P.*)") - matches = epoch_matcher.search(version).groupdict() - return matches.get("epoch", ""), matches.get("version") + result = epoch_matcher.search(version) + assert result is not None + matches = result.groupdict() + return matches.get("epoch", ""), matches["version"] @classmethod def from_system( @@ -422,32 +425,33 @@ def from_installed_package( ) for line in lines: - try: - matches = dpkg_matcher.search(line).groupdict() - package_status = matches["package_status"] - - if not package_status.endswith("i"): - logger.debug( - "package '%s' in dpkg output but not installed, status: '%s'", - package, - package_status, - ) - break - - epoch, split_version = DebianPackage._get_epoch_from_version(matches["version"]) - pkg = DebianPackage( - matches["package_name"], - split_version, - epoch, - matches["arch"], - PackageState.Present, - ) - if (pkg.arch == "all" or pkg.arch == arch) and ( - version == "" or str(pkg.version) == version - ): - return pkg - except AttributeError: + result = dpkg_matcher.search(line) + if result is None: logger.warning("dpkg matcher could not parse line: %s", line) + continue + matches = result.groupdict() + package_status = matches["package_status"] + + if not package_status.endswith("i"): + logger.debug( + "package '%s' in dpkg output but not installed, status: '%s'", + package, + package_status, + ) + break + + epoch, split_version = DebianPackage._get_epoch_from_version(matches["version"]) + pkg = DebianPackage( + name=matches["package_name"], + version=split_version, + epoch=epoch, + arch=matches["arch"], + state=PackageState.Present, + ) + if (pkg.arch == "all" or pkg.arch == arch) and ( + version == "" or str(pkg.version) == version + ): + return pkg # If we didn't find it, fail through raise PackageNotFoundError("Package {}.{} is not installed!".format(package, arch)) @@ -486,7 +490,7 @@ def from_apt_cache( for pkg_raw in pkg_groups: lines = str(pkg_raw).splitlines() - vals = {} + vals: dict[str, str] = {} for line in lines: if line.startswith(keys): items = line.split(":", 1) @@ -496,11 +500,11 @@ def from_apt_cache( epoch, split_version = DebianPackage._get_epoch_from_version(vals["Version"]) pkg = DebianPackage( - vals["Package"], - split_version, - epoch, - vals["Architecture"], - PackageState.Available, + name=vals["Package"], + version=split_version, + epoch=epoch, + arch=vals["Architecture"], + state=PackageState.Available, ) if (pkg.arch == "all" or pkg.arch == arch) and ( @@ -555,15 +559,15 @@ def _get_parts(self, version: str) -> tuple[str, str]: upstream, debian = version.rsplit("-", 1) return upstream, debian - def _listify(self, revision: str) -> list[str]: - """Split a revision string into a listself. + def _listify(self, revision: str) -> list[str | int]: + """Split a revision string into a list. This list is comprised of alternating between strings and numbers, padded on either end to always be "str, int, str, int..." and always be of even length. This allows us to trivially implement the comparison algorithm described. """ - result = [] + result: list[str | int] = [] while revision: rev_1, remains = self._get_alphas(revision) rev_2, remains = self._get_digits(remains) @@ -596,7 +600,7 @@ def _get_digits(self, revision: str) -> tuple[int, str]: # string is entirely digits return int(revision), "" - def _dstringcmp(self, a, b): # noqa: C901 + def _dstringcmp(self, a: str, b: str) -> Literal[-1, 0, 1]: """Debian package version string section lexical sort algorithm. The lexical comparison is a comparison of ASCII values modified so @@ -635,7 +639,7 @@ def _dstringcmp(self, a, b): # noqa: C901 return 1 return -1 - def _compare_revision_strings(self, first: str, second: str): # noqa: C901 + def _compare_revision_strings(self, first: str, second: str) -> Literal[-1, 0, 1]: """Compare two debian revision strings.""" if first == second: return 0 @@ -675,7 +679,7 @@ def _compare_revision_strings(self, first: str, second: str): # noqa: C901 return -1 return -1 - def _compare_version(self, other) -> int: + def _compare_version(self, other: Version) -> Literal[-1, 0, 1]: if (self.number, self.epoch) == (other.number, other.epoch): return 0 @@ -698,36 +702,52 @@ def _compare_version(self, other) -> int: return 0 - def __lt__(self, other) -> bool: + def __lt__(self, other: Version) -> bool: """Less than magic method impl.""" return self._compare_version(other) < 0 - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: """Equality magic method impl.""" + if not isinstance(other, Version): + return False return self._compare_version(other) == 0 - def __gt__(self, other) -> bool: + def __gt__(self, other: Version) -> bool: """Greater than magic method impl.""" return self._compare_version(other) > 0 - def __le__(self, other) -> bool: + def __le__(self, other: Version) -> bool: """Less than or equal to magic method impl.""" return self.__eq__(other) or self.__lt__(other) - def __ge__(self, other) -> bool: + def __ge__(self, other: Version) -> bool: """Greater than or equal to magic method impl.""" return self.__gt__(other) or self.__eq__(other) - def __ne__(self, other) -> bool: + def __ne__(self, other: object) -> bool: """Not equal to magic method impl.""" return not self.__eq__(other) +@typing.overload +def add_package( + package_names: str, + version: str | None = "", + arch: str | None = "", + update_cache: bool = False, +) -> DebianPackage: ... +@typing.overload +def add_package( + package_names: str | list[str], + version: str | None = "", + arch: str | None = "", + update_cache: bool = False, +) -> DebianPackage | list[DebianPackage]: ... def add_package( package_names: str | list[str], version: str | None = "", arch: str | None = "", - update_cache: bool | None = False, + update_cache: bool = False, ) -> DebianPackage | list[DebianPackage]: """Add a package or list of packages to the system. @@ -748,8 +768,6 @@ def add_package( update() cache_refreshed = True - packages = {"success": [], "retry": [], "failed": []} - package_names = [package_names] if isinstance(package_names, str) else package_names if not package_names: raise TypeError("Expected at least one package name to add, received zero!") @@ -759,36 +777,40 @@ def add_package( "Explicit version should not be set if more than one package is being added!" ) + succeeded: list[DebianPackage] = [] + retry: list[str] = [] + failed: list[str] = [] + for p in package_names: - pkg, success = _add(p, version, arch) - if success: - packages["success"].append(pkg) + pkg, _ = _add(p, version, arch) + if isinstance(pkg, DebianPackage): + succeeded.append(pkg) else: logger.warning("failed to locate and install/update '%s'", pkg) - packages["retry"].append(p) + retry.append(p) - if packages["retry"] and not cache_refreshed: + if retry and not cache_refreshed: logger.info("updating the apt-cache and retrying installation of failed packages.") update() - for p in packages["retry"]: - pkg, success = _add(p, version, arch) - if success: - packages["success"].append(pkg) + for p in retry: + pkg, _ = _add(p, version, arch) + if isinstance(pkg, DebianPackage): + succeeded.append(pkg) else: - packages["failed"].append(p) + failed.append(p) - if packages["failed"]: - raise PackageError("Failed to install packages: {}".format(", ".join(packages["failed"]))) + if failed: + raise PackageError(f"Failed to install packages: {', '.join(failed)}") - return packages["success"] if len(packages["success"]) > 1 else packages["success"][0] + return succeeded if len(succeeded) > 1 else succeeded[0] def _add( name: str, version: str | None = "", arch: str | None = "", -) -> tuple[DebianPackage | str, bool]: +) -> tuple[DebianPackage, Literal[True]] | tuple[str, Literal[False]]: """Add a package to the system. Args: @@ -807,6 +829,14 @@ def _add( return name, False +@typing.overload +def remove_package( + package_names: str, +) -> DebianPackage: ... +@typing.overload +def remove_package( + package_names: list[str], +) -> DebianPackage | list[DebianPackage]: ... def remove_package( package_names: str | list[str], ) -> DebianPackage | list[DebianPackage]: @@ -1117,7 +1147,9 @@ def _get_keyid_by_gpg_key(key_material: bytes) -> str: if "gpg: no valid OpenPGP data found." in err: raise GPGKeyError("Invalid GPG key material provided") # from gnupg2 docs: fpr :: Fingerprint (fingerprint is in field 10) - return re.search(r"^fpr:{9}([0-9A-F]{40}):$", out, re.MULTILINE).group(1) + result = re.search(r"^fpr:{9}([0-9A-F]{40}):$", out, re.MULTILINE) + assert result is not None + return result.group(1) @staticmethod def _get_key_by_keyid(keyid: str) -> str: @@ -1271,7 +1303,7 @@ def __len__(self) -> int: """Return number of repositories in map.""" return len(self._repository_map) - def __iter__(self) -> Iterator[DebianRepository]: + def __iter__(self) -> Iterator[DebianRepository]: # pyright: ignore[reportIncompatibleMethodOverride] """Return iterator for RepositoryMapping. Iterates over the DebianRepository values rather than the string names. diff --git a/tox.ini b/tox.ini index 420a8e98..951e1a80 100644 --- a/tox.ini +++ b/tox.ini @@ -53,6 +53,7 @@ commands = #pyright {[vars]lib_dir} {posargs} # FIXME: enable static analysis of all libs #pyright {[vars]test_dir} {posargs} # FIXME: enable static analysis of tests echo 'NOTE: Only manually specified libs are being static type checked currently' + pyright lib/charms/operator_libs_linux/v0/apt.py {posargs} pyright lib/charms/operator_libs_linux/v2/snap.py {posargs} [testenv:unit] From 2453b0f093b5d3a7d0aad34f8204841b262cff5f Mon Sep 17 00:00:00 2001 From: James Garner Date: Wed, 29 Jan 2025 12:09:23 +1300 Subject: [PATCH 2/4] fix: satisfy type checker about revision comparison funcs; add FIXMEs --- lib/charms/operator_libs_linux/v0/apt.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/lib/charms/operator_libs_linux/v0/apt.py b/lib/charms/operator_libs_linux/v0/apt.py index c2313864..58c29795 100644 --- a/lib/charms/operator_libs_linux/v0/apt.py +++ b/lib/charms/operator_libs_linux/v0/apt.py @@ -631,6 +631,8 @@ def _dstringcmp(self, a: str, b: str) -> Literal[-1, 0, 1]: return -1 except IndexError: # a is longer than b but otherwise equal, greater unless there are tildes + # FIXME: type checker thinks "char" is possibly unbound -- I want to refactor this anyway + assert isinstance(char, str) # pyright: ignore[reportPossiblyUnboundVariable] if char == "~": return -1 return 1 @@ -655,27 +657,34 @@ def _compare_revision_strings(self, first: str, second: str) -> Literal[-1, 0, 1 # explicitly raise IndexError if we've fallen off the edge of list2 if i >= len(second_list): raise IndexError + other = second_list[i] # if the items are equal, next - if item == second_list[i]: + if item == other: continue # numeric comparison if isinstance(item, int): - if item > second_list[i]: + assert isinstance(other, int) + if item > other: return 1 - if item < second_list[i]: + if item < other: return -1 else: # string comparison - return self._dstringcmp(item, second_list[i]) + assert isinstance(other, str) + return self._dstringcmp(item, other) except IndexError: # rev1 is longer than rev2 but otherwise equal, hence greater # ...except for goddamn tildes - if first_list[len(second_list)][0][0] == "~": + # FIXME: bug?? we return 1 in both cases + # FIXME: first_list[len(second_list)] should be a string, why are we indexing to 0 twice? + if first_list[len(second_list)][0][0] == "~": # type: ignore return 1 return 1 # rev1 is shorter than rev2 but otherwise equal, hence lesser # ...except for goddamn tildes - if second_list[len(first_list)][0][0] == "~": + # FIXME: bug?? we return -1 in both cases + # FIXME: first_list[len(second_list)] should be a string, why are we indexing to 0 twice? + if second_list[len(first_list)][0][0] == "~": # type: ignore return -1 return -1 From 5dedd9949054a342166d97e9ad503a5976f27ea5 Mon Sep 17 00:00:00 2001 From: James Garner Date: Wed, 29 Jan 2025 16:49:35 +1300 Subject: [PATCH 3/4] fix: remove unnecessary assert in one case --- lib/charms/operator_libs_linux/v0/apt.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/charms/operator_libs_linux/v0/apt.py b/lib/charms/operator_libs_linux/v0/apt.py index 58c29795..f90f0bc7 100644 --- a/lib/charms/operator_libs_linux/v0/apt.py +++ b/lib/charms/operator_libs_linux/v0/apt.py @@ -631,9 +631,10 @@ def _dstringcmp(self, a: str, b: str) -> Literal[-1, 0, 1]: return -1 except IndexError: # a is longer than b but otherwise equal, greater unless there are tildes - # FIXME: type checker thinks "char" is possibly unbound -- I want to refactor this anyway - assert isinstance(char, str) # pyright: ignore[reportPossiblyUnboundVariable] - if char == "~": + # FIXME: type checker thinks "char" is possibly unbound as it's a loop variable + # but it won't be since the IndexError can only occur inside the loop + # -- I'd like to refactor away this `try ... except` anyway + if char == "~": # pyright: ignore[reportPossiblyUnboundVariable] return -1 return 1 # if we get here, a is shorter than b but otherwise equal, so check for tildes... From bbbcf4669e1b6afd9b939550c8c0534dd8c07635 Mon Sep 17 00:00:00 2001 From: James Garner Date: Thu, 30 Jan 2025 11:26:22 +1300 Subject: [PATCH 4/4] fix: correct add_package overload --- lib/charms/operator_libs_linux/v0/apt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/charms/operator_libs_linux/v0/apt.py b/lib/charms/operator_libs_linux/v0/apt.py index f90f0bc7..3df5f788 100644 --- a/lib/charms/operator_libs_linux/v0/apt.py +++ b/lib/charms/operator_libs_linux/v0/apt.py @@ -748,7 +748,7 @@ def add_package( ) -> DebianPackage: ... @typing.overload def add_package( - package_names: str | list[str], + package_names: list[str], version: str | None = "", arch: str | None = "", update_cache: bool = False,