Skip to content

Commit

Permalink
style: update type hints and enable type checking in ci for apt.py (#151
Browse files Browse the repository at this point in the history
)

Enable static analysis of `apt.py` in ci. `LIBPATCH` will be bumped after one more
follow up PR handling `ruff` linting for this lib.
  • Loading branch information
james-garner-canonical authored Jan 29, 2025
1 parent 75f44e3 commit 06633c2
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 74 deletions.
190 changes: 116 additions & 74 deletions lib/charms/operator_libs_linux/v0/apt.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,10 @@
import os
import re
import subprocess
import typing
from enum import Enum
from subprocess import PIPE, CalledProcessError, check_output
from typing import Any, Iterable, Iterator, Mapping
from typing import Any, Iterable, Iterator, Literal, Mapping
from urllib.parse import urlparse

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -197,7 +198,7 @@ def __init__(
self._state = state
self._version = Version(version, epoch)

def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
"""Equality for comparison.
Args:
Expand Down Expand Up @@ -232,7 +233,7 @@ def __str__(self):
@staticmethod
def _apt(
command: str,
package_names: str | list,
package_names: str | list[str],
optargs: list[str] | None = None,
) -> None:
"""Wrap package management commands for Debian/Ubuntu systems.
Expand Down Expand Up @@ -346,8 +347,10 @@ def fullversion(self) -> str:
def _get_epoch_from_version(version: str) -> tuple[str, str]:
"""Pull the epoch, if any, out of a version string."""
epoch_matcher = re.compile(r"^((?P<epoch>\d+):)?(?P<version>.*)")
matches = epoch_matcher.search(version).groupdict()
return matches.get("epoch", ""), matches.get("version")
result = epoch_matcher.search(version)
assert result is not None
matches = result.groupdict()
return matches.get("epoch", ""), matches["version"]

@classmethod
def from_system(
Expand Down Expand Up @@ -422,32 +425,33 @@ def from_installed_package(
)

for line in lines:
try:
matches = dpkg_matcher.search(line).groupdict()
package_status = matches["package_status"]

if not package_status.endswith("i"):
logger.debug(
"package '%s' in dpkg output but not installed, status: '%s'",
package,
package_status,
)
break

epoch, split_version = DebianPackage._get_epoch_from_version(matches["version"])
pkg = DebianPackage(
matches["package_name"],
split_version,
epoch,
matches["arch"],
PackageState.Present,
)
if (pkg.arch == "all" or pkg.arch == arch) and (
version == "" or str(pkg.version) == version
):
return pkg
except AttributeError:
result = dpkg_matcher.search(line)
if result is None:
logger.warning("dpkg matcher could not parse line: %s", line)
continue
matches = result.groupdict()
package_status = matches["package_status"]

if not package_status.endswith("i"):
logger.debug(
"package '%s' in dpkg output but not installed, status: '%s'",
package,
package_status,
)
break

epoch, split_version = DebianPackage._get_epoch_from_version(matches["version"])
pkg = DebianPackage(
name=matches["package_name"],
version=split_version,
epoch=epoch,
arch=matches["arch"],
state=PackageState.Present,
)
if (pkg.arch == "all" or pkg.arch == arch) and (
version == "" or str(pkg.version) == version
):
return pkg

# If we didn't find it, fail through
raise PackageNotFoundError("Package {}.{} is not installed!".format(package, arch))
Expand Down Expand Up @@ -486,7 +490,7 @@ def from_apt_cache(

for pkg_raw in pkg_groups:
lines = str(pkg_raw).splitlines()
vals = {}
vals: dict[str, str] = {}
for line in lines:
if line.startswith(keys):
items = line.split(":", 1)
Expand All @@ -496,11 +500,11 @@ def from_apt_cache(

epoch, split_version = DebianPackage._get_epoch_from_version(vals["Version"])
pkg = DebianPackage(
vals["Package"],
split_version,
epoch,
vals["Architecture"],
PackageState.Available,
name=vals["Package"],
version=split_version,
epoch=epoch,
arch=vals["Architecture"],
state=PackageState.Available,
)

if (pkg.arch == "all" or pkg.arch == arch) and (
Expand Down Expand Up @@ -555,15 +559,15 @@ def _get_parts(self, version: str) -> tuple[str, str]:
upstream, debian = version.rsplit("-", 1)
return upstream, debian

def _listify(self, revision: str) -> list[str]:
"""Split a revision string into a listself.
def _listify(self, revision: str) -> list[str | int]:
"""Split a revision string into a list.
This list is comprised of alternating between strings and numbers,
padded on either end to always be "str, int, str, int..." and
always be of even length. This allows us to trivially implement the
comparison algorithm described.
"""
result = []
result: list[str | int] = []
while revision:
rev_1, remains = self._get_alphas(revision)
rev_2, remains = self._get_digits(remains)
Expand Down Expand Up @@ -596,7 +600,7 @@ def _get_digits(self, revision: str) -> tuple[int, str]:
# string is entirely digits
return int(revision), ""

def _dstringcmp(self, a, b): # noqa: C901
def _dstringcmp(self, a: str, b: str) -> Literal[-1, 0, 1]:
"""Debian package version string section lexical sort algorithm.
The lexical comparison is a comparison of ASCII values modified so
Expand Down Expand Up @@ -627,15 +631,18 @@ def _dstringcmp(self, a, b): # noqa: C901
return -1
except IndexError:
# a is longer than b but otherwise equal, greater unless there are tildes
if char == "~":
# FIXME: type checker thinks "char" is possibly unbound as it's a loop variable
# but it won't be since the IndexError can only occur inside the loop
# -- I'd like to refactor away this `try ... except` anyway
if char == "~": # pyright: ignore[reportPossiblyUnboundVariable]
return -1
return 1
# if we get here, a is shorter than b but otherwise equal, so check for tildes...
if b[len(a)] == "~":
return 1
return -1

def _compare_revision_strings(self, first: str, second: str): # noqa: C901
def _compare_revision_strings(self, first: str, second: str) -> Literal[-1, 0, 1]:
"""Compare two debian revision strings."""
if first == second:
return 0
Expand All @@ -651,31 +658,38 @@ def _compare_revision_strings(self, first: str, second: str): # noqa: C901
# explicitly raise IndexError if we've fallen off the edge of list2
if i >= len(second_list):
raise IndexError
other = second_list[i]
# if the items are equal, next
if item == second_list[i]:
if item == other:
continue
# numeric comparison
if isinstance(item, int):
if item > second_list[i]:
assert isinstance(other, int)
if item > other:
return 1
if item < second_list[i]:
if item < other:
return -1
else:
# string comparison
return self._dstringcmp(item, second_list[i])
assert isinstance(other, str)
return self._dstringcmp(item, other)
except IndexError:
# rev1 is longer than rev2 but otherwise equal, hence greater
# ...except for goddamn tildes
if first_list[len(second_list)][0][0] == "~":
# FIXME: bug?? we return 1 in both cases
# FIXME: first_list[len(second_list)] should be a string, why are we indexing to 0 twice?
if first_list[len(second_list)][0][0] == "~": # type: ignore
return 1
return 1
# rev1 is shorter than rev2 but otherwise equal, hence lesser
# ...except for goddamn tildes
if second_list[len(first_list)][0][0] == "~":
# FIXME: bug?? we return -1 in both cases
# FIXME: first_list[len(second_list)] should be a string, why are we indexing to 0 twice?
if second_list[len(first_list)][0][0] == "~": # type: ignore
return -1
return -1

def _compare_version(self, other) -> int:
def _compare_version(self, other: Version) -> Literal[-1, 0, 1]:
if (self.number, self.epoch) == (other.number, other.epoch):
return 0

Expand All @@ -698,36 +712,52 @@ def _compare_version(self, other) -> int:

return 0

def __lt__(self, other) -> bool:
def __lt__(self, other: Version) -> bool:
"""Less than magic method impl."""
return self._compare_version(other) < 0

def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
"""Equality magic method impl."""
if not isinstance(other, Version):
return False
return self._compare_version(other) == 0

def __gt__(self, other) -> bool:
def __gt__(self, other: Version) -> bool:
"""Greater than magic method impl."""
return self._compare_version(other) > 0

def __le__(self, other) -> bool:
def __le__(self, other: Version) -> bool:
"""Less than or equal to magic method impl."""
return self.__eq__(other) or self.__lt__(other)

def __ge__(self, other) -> bool:
def __ge__(self, other: Version) -> bool:
"""Greater than or equal to magic method impl."""
return self.__gt__(other) or self.__eq__(other)

def __ne__(self, other) -> bool:
def __ne__(self, other: object) -> bool:
"""Not equal to magic method impl."""
return not self.__eq__(other)


@typing.overload
def add_package(
package_names: str,
version: str | None = "",
arch: str | None = "",
update_cache: bool = False,
) -> DebianPackage: ...
@typing.overload
def add_package(
package_names: list[str],
version: str | None = "",
arch: str | None = "",
update_cache: bool = False,
) -> DebianPackage | list[DebianPackage]: ...
def add_package(
package_names: str | list[str],
version: str | None = "",
arch: str | None = "",
update_cache: bool | None = False,
update_cache: bool = False,
) -> DebianPackage | list[DebianPackage]:
"""Add a package or list of packages to the system.
Expand All @@ -748,8 +778,6 @@ def add_package(
update()
cache_refreshed = True

packages = {"success": [], "retry": [], "failed": []}

package_names = [package_names] if isinstance(package_names, str) else package_names
if not package_names:
raise TypeError("Expected at least one package name to add, received zero!")
Expand All @@ -759,36 +787,40 @@ def add_package(
"Explicit version should not be set if more than one package is being added!"
)

succeeded: list[DebianPackage] = []
retry: list[str] = []
failed: list[str] = []

for p in package_names:
pkg, success = _add(p, version, arch)
if success:
packages["success"].append(pkg)
pkg, _ = _add(p, version, arch)
if isinstance(pkg, DebianPackage):
succeeded.append(pkg)
else:
logger.warning("failed to locate and install/update '%s'", pkg)
packages["retry"].append(p)
retry.append(p)

if packages["retry"] and not cache_refreshed:
if retry and not cache_refreshed:
logger.info("updating the apt-cache and retrying installation of failed packages.")
update()

for p in packages["retry"]:
pkg, success = _add(p, version, arch)
if success:
packages["success"].append(pkg)
for p in retry:
pkg, _ = _add(p, version, arch)
if isinstance(pkg, DebianPackage):
succeeded.append(pkg)
else:
packages["failed"].append(p)
failed.append(p)

if packages["failed"]:
raise PackageError("Failed to install packages: {}".format(", ".join(packages["failed"])))
if failed:
raise PackageError(f"Failed to install packages: {', '.join(failed)}")

return packages["success"] if len(packages["success"]) > 1 else packages["success"][0]
return succeeded if len(succeeded) > 1 else succeeded[0]


def _add(
name: str,
version: str | None = "",
arch: str | None = "",
) -> tuple[DebianPackage | str, bool]:
) -> tuple[DebianPackage, Literal[True]] | tuple[str, Literal[False]]:
"""Add a package to the system.
Args:
Expand All @@ -807,6 +839,14 @@ def _add(
return name, False


@typing.overload
def remove_package(
package_names: str,
) -> DebianPackage: ...
@typing.overload
def remove_package(
package_names: list[str],
) -> DebianPackage | list[DebianPackage]: ...
def remove_package(
package_names: str | list[str],
) -> DebianPackage | list[DebianPackage]:
Expand Down Expand Up @@ -1117,7 +1157,9 @@ def _get_keyid_by_gpg_key(key_material: bytes) -> str:
if "gpg: no valid OpenPGP data found." in err:
raise GPGKeyError("Invalid GPG key material provided")
# from gnupg2 docs: fpr :: Fingerprint (fingerprint is in field 10)
return re.search(r"^fpr:{9}([0-9A-F]{40}):$", out, re.MULTILINE).group(1)
result = re.search(r"^fpr:{9}([0-9A-F]{40}):$", out, re.MULTILINE)
assert result is not None
return result.group(1)

@staticmethod
def _get_key_by_keyid(keyid: str) -> str:
Expand Down Expand Up @@ -1271,7 +1313,7 @@ def __len__(self) -> int:
"""Return number of repositories in map."""
return len(self._repository_map)

def __iter__(self) -> Iterator[DebianRepository]:
def __iter__(self) -> Iterator[DebianRepository]: # pyright: ignore[reportIncompatibleMethodOverride]
"""Return iterator for RepositoryMapping.
Iterates over the DebianRepository values rather than the string names.
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ commands =
#pyright {[vars]lib_dir} {posargs} # FIXME: enable static analysis of all libs
#pyright {[vars]test_dir} {posargs} # FIXME: enable static analysis of tests
echo 'NOTE: Only manually specified libs are being static type checked currently'
pyright lib/charms/operator_libs_linux/v0/apt.py {posargs}
pyright lib/charms/operator_libs_linux/v2/snap.py {posargs}

[testenv:unit]
Expand Down

0 comments on commit 06633c2

Please sign in to comment.