Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style: update type hints and enable type checking in ci for apt.py #151

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 116 additions & 74 deletions lib/charms/operator_libs_linux/v0/apt.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,10 @@
import os
import re
import subprocess
import typing
from enum import Enum
from subprocess import PIPE, CalledProcessError, check_output
from typing import Any, Iterable, Iterator, Mapping
from typing import Any, Iterable, Iterator, Literal, Mapping
from urllib.parse import urlparse

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -197,7 +198,7 @@ def __init__(
self._state = state
self._version = Version(version, epoch)

def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
"""Equality for comparison.

Args:
Expand Down Expand Up @@ -232,7 +233,7 @@ def __str__(self):
@staticmethod
def _apt(
command: str,
package_names: str | list,
package_names: str | list[str],
optargs: list[str] | None = None,
) -> None:
"""Wrap package management commands for Debian/Ubuntu systems.
Expand Down Expand Up @@ -346,8 +347,10 @@ def fullversion(self) -> str:
def _get_epoch_from_version(version: str) -> tuple[str, str]:
"""Pull the epoch, if any, out of a version string."""
epoch_matcher = re.compile(r"^((?P<epoch>\d+):)?(?P<version>.*)")
matches = epoch_matcher.search(version).groupdict()
return matches.get("epoch", ""), matches.get("version")
result = epoch_matcher.search(version)
assert result is not None
matches = result.groupdict()
return matches.get("epoch", ""), matches["version"]
Comment on lines -349 to +353
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This issue here is that re.Pattern.search(...) can return None. Presumably it never does on well formed version strings, so we can just assert.

Note also the change to returning matches["version"] instead of matches.get("version"), as the latter is also typed as possibly returning None, which contradicts the function signature. The regex match should always have an entry for "version" since it matches .*.


@classmethod
def from_system(
Expand Down Expand Up @@ -422,32 +425,33 @@ def from_installed_package(
)

for line in lines:
try:
matches = dpkg_matcher.search(line).groupdict()
package_status = matches["package_status"]

if not package_status.endswith("i"):
logger.debug(
"package '%s' in dpkg output but not installed, status: '%s'",
package,
package_status,
)
break

epoch, split_version = DebianPackage._get_epoch_from_version(matches["version"])
pkg = DebianPackage(
matches["package_name"],
split_version,
epoch,
matches["arch"],
PackageState.Present,
)
if (pkg.arch == "all" or pkg.arch == arch) and (
version == "" or str(pkg.version) == version
):
return pkg
except AttributeError:
result = dpkg_matcher.search(line)
if result is None:
logger.warning("dpkg matcher could not parse line: %s", line)
continue
matches = result.groupdict()
package_status = matches["package_status"]

if not package_status.endswith("i"):
logger.debug(
"package '%s' in dpkg output but not installed, status: '%s'",
package,
package_status,
)
break

epoch, split_version = DebianPackage._get_epoch_from_version(matches["version"])
pkg = DebianPackage(
name=matches["package_name"],
version=split_version,
epoch=epoch,
arch=matches["arch"],
state=PackageState.Present,
)
if (pkg.arch == "all" or pkg.arch == arch) and (
version == "" or str(pkg.version) == version
):
return pkg
Comment on lines -425 to +454
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally the whole loop body is wrapped in a try ... except AttributeError, again due to re.Pattern.search(...) potentially returning None instead of a re.Match. This change explicitly checks if the search result is None, allowing the loop body to be dedented.


# If we didn't find it, fail through
raise PackageNotFoundError("Package {}.{} is not installed!".format(package, arch))
Expand Down Expand Up @@ -486,7 +490,7 @@ def from_apt_cache(

for pkg_raw in pkg_groups:
lines = str(pkg_raw).splitlines()
vals = {}
vals: dict[str, str] = {}
for line in lines:
if line.startswith(keys):
items = line.split(":", 1)
Expand All @@ -496,11 +500,11 @@ def from_apt_cache(

epoch, split_version = DebianPackage._get_epoch_from_version(vals["Version"])
pkg = DebianPackage(
vals["Package"],
split_version,
epoch,
vals["Architecture"],
PackageState.Available,
name=vals["Package"],
version=split_version,
epoch=epoch,
arch=vals["Architecture"],
state=PackageState.Available,
Comment on lines -499 to +507
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just using keyword arguments because why not

)

if (pkg.arch == "all" or pkg.arch == arch) and (
Expand Down Expand Up @@ -555,15 +559,15 @@ def _get_parts(self, version: str) -> tuple[str, str]:
upstream, debian = version.rsplit("-", 1)
return upstream, debian

def _listify(self, revision: str) -> list[str]:
"""Split a revision string into a listself.
def _listify(self, revision: str) -> list[str | int]:
"""Split a revision string into a list.
Comment on lines -558 to +563
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Return type is a list that alternates between strings and ints. This will come up later


This list is comprised of alternating between strings and numbers,
padded on either end to always be "str, int, str, int..." and
always be of even length. This allows us to trivially implement the
comparison algorithm described.
"""
result = []
result: list[str | int] = []
while revision:
rev_1, remains = self._get_alphas(revision)
rev_2, remains = self._get_digits(remains)
Expand Down Expand Up @@ -596,7 +600,7 @@ def _get_digits(self, revision: str) -> tuple[int, str]:
# string is entirely digits
return int(revision), ""

def _dstringcmp(self, a, b): # noqa: C901
def _dstringcmp(self, a: str, b: str) -> Literal[-1, 0, 1]:
"""Debian package version string section lexical sort algorithm.

The lexical comparison is a comparison of ASCII values modified so
Expand Down Expand Up @@ -627,15 +631,18 @@ def _dstringcmp(self, a, b): # noqa: C901
return -1
except IndexError:
# a is longer than b but otherwise equal, greater unless there are tildes
if char == "~":
# FIXME: type checker thinks "char" is possibly unbound as it's a loop variable
# but it won't be since the IndexError can only occur inside the loop
# -- I'd like to refactor away this `try ... except` anyway
if char == "~": # pyright: ignore[reportPossiblyUnboundVariable]
return -1
return 1
# if we get here, a is shorter than b but otherwise equal, so check for tildes...
if b[len(a)] == "~":
return 1
return -1

def _compare_revision_strings(self, first: str, second: str): # noqa: C901
def _compare_revision_strings(self, first: str, second: str) -> Literal[-1, 0, 1]:
"""Compare two debian revision strings."""
if first == second:
return 0
Expand All @@ -651,31 +658,38 @@ def _compare_revision_strings(self, first: str, second: str): # noqa: C901
# explicitly raise IndexError if we've fallen off the edge of list2
if i >= len(second_list):
raise IndexError
other = second_list[i]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use a local variable to allow type narrowing -- and it's more efficient anyway

# if the items are equal, next
if item == second_list[i]:
if item == other:
continue
# numeric comparison
if isinstance(item, int):
if item > second_list[i]:
assert isinstance(other, int)
if item > other:
return 1
if item < second_list[i]:
if item < other:
return -1
else:
# string comparison
return self._dstringcmp(item, second_list[i])
assert isinstance(other, str)
return self._dstringcmp(item, other)
Comment on lines +667 to +675
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is that list of alternating integers and strings -- well a pair of them being compared

except IndexError:
# rev1 is longer than rev2 but otherwise equal, hence greater
# ...except for goddamn tildes
if first_list[len(second_list)][0][0] == "~":
# FIXME: bug?? we return 1 in both cases
# FIXME: first_list[len(second_list)] should be a string, why are we indexing to 0 twice?
if first_list[len(second_list)][0][0] == "~": # type: ignore
return 1
return 1
# rev1 is shorter than rev2 but otherwise equal, hence lesser
# ...except for goddamn tildes
if second_list[len(first_list)][0][0] == "~":
# FIXME: bug?? we return -1 in both cases
# FIXME: first_list[len(second_list)] should be a string, why are we indexing to 0 twice?
if second_list[len(first_list)][0][0] == "~": # type: ignore
return -1
return -1
Comment on lines 676 to 690
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type: ignore due to the str | int ambiguity -- and this is weird and potentially buggy so I'd like to save further changes for a separate PR (unit test coverage of these comparison functions is spotty)


def _compare_version(self, other) -> int:
def _compare_version(self, other: Version) -> Literal[-1, 0, 1]:
if (self.number, self.epoch) == (other.number, other.epoch):
return 0

Expand All @@ -698,36 +712,52 @@ def _compare_version(self, other) -> int:

return 0

def __lt__(self, other) -> bool:
def __lt__(self, other: Version) -> bool:
"""Less than magic method impl."""
return self._compare_version(other) < 0

def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
"""Equality magic method impl."""
if not isinstance(other, Version):
return False
return self._compare_version(other) == 0
Comment on lines -705 to 723
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically a behaviour change, so maybe this should be dropped -- maybe it's better to get an AttributeError when you try to compare a Version for equality with something else?


def __gt__(self, other) -> bool:
def __gt__(self, other: Version) -> bool:
"""Greater than magic method impl."""
return self._compare_version(other) > 0

def __le__(self, other) -> bool:
def __le__(self, other: Version) -> bool:
"""Less than or equal to magic method impl."""
return self.__eq__(other) or self.__lt__(other)

def __ge__(self, other) -> bool:
def __ge__(self, other: Version) -> bool:
"""Greater than or equal to magic method impl."""
return self.__gt__(other) or self.__eq__(other)

def __ne__(self, other) -> bool:
def __ne__(self, other: object) -> bool:
"""Not equal to magic method impl."""
return not self.__eq__(other)


@typing.overload
def add_package(
package_names: str,
version: str | None = "",
arch: str | None = "",
update_cache: bool = False,
) -> DebianPackage: ...
@typing.overload
def add_package(
package_names: list[str],
version: str | None = "",
arch: str | None = "",
update_cache: bool = False,
) -> DebianPackage | list[DebianPackage]: ...
def add_package(
package_names: str | list[str],
version: str | None = "",
arch: str | None = "",
update_cache: bool | None = False,
update_cache: bool = False,
) -> DebianPackage | list[DebianPackage]:
"""Add a package or list of packages to the system.

Expand All @@ -748,8 +778,6 @@ def add_package(
update()
cache_refreshed = True

packages = {"success": [], "retry": [], "failed": []}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I separate these into individual variables for better typing

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simpler too!


package_names = [package_names] if isinstance(package_names, str) else package_names
if not package_names:
raise TypeError("Expected at least one package name to add, received zero!")
Expand All @@ -759,36 +787,40 @@ def add_package(
"Explicit version should not be set if more than one package is being added!"
)

succeeded: list[DebianPackage] = []
retry: list[str] = []
failed: list[str] = []

for p in package_names:
pkg, success = _add(p, version, arch)
if success:
packages["success"].append(pkg)
pkg, _ = _add(p, version, arch)
if isinstance(pkg, DebianPackage):
succeeded.append(pkg)
else:
logger.warning("failed to locate and install/update '%s'", pkg)
packages["retry"].append(p)
retry.append(p)

if packages["retry"] and not cache_refreshed:
if retry and not cache_refreshed:
logger.info("updating the apt-cache and retrying installation of failed packages.")
update()

for p in packages["retry"]:
pkg, success = _add(p, version, arch)
if success:
packages["success"].append(pkg)
for p in retry:
pkg, _ = _add(p, version, arch)
if isinstance(pkg, DebianPackage):
succeeded.append(pkg)
else:
packages["failed"].append(p)
failed.append(p)

if packages["failed"]:
raise PackageError("Failed to install packages: {}".format(", ".join(packages["failed"])))
if failed:
raise PackageError(f"Failed to install packages: {', '.join(failed)}")

return packages["success"] if len(packages["success"]) > 1 else packages["success"][0]
return succeeded if len(succeeded) > 1 else succeeded[0]


def _add(
name: str,
version: str | None = "",
arch: str | None = "",
) -> tuple[DebianPackage | str, bool]:
) -> tuple[DebianPackage, Literal[True]] | tuple[str, Literal[False]]:
"""Add a package to the system.

Args:
Expand All @@ -807,6 +839,14 @@ def _add(
return name, False


@typing.overload
def remove_package(
package_names: str,
) -> DebianPackage: ...
@typing.overload
def remove_package(
package_names: list[str],
) -> DebianPackage | list[DebianPackage]: ...
def remove_package(
package_names: str | list[str],
) -> DebianPackage | list[DebianPackage]:
Expand Down Expand Up @@ -1117,7 +1157,9 @@ def _get_keyid_by_gpg_key(key_material: bytes) -> str:
if "gpg: no valid OpenPGP data found." in err:
raise GPGKeyError("Invalid GPG key material provided")
# from gnupg2 docs: fpr :: Fingerprint (fingerprint is in field 10)
return re.search(r"^fpr:{9}([0-9A-F]{40}):$", out, re.MULTILINE).group(1)
result = re.search(r"^fpr:{9}([0-9A-F]{40}):$", out, re.MULTILINE)
assert result is not None
return result.group(1)

@staticmethod
def _get_key_by_keyid(keyid: str) -> str:
Expand Down Expand Up @@ -1271,7 +1313,7 @@ def __len__(self) -> int:
"""Return number of repositories in map."""
return len(self._repository_map)

def __iter__(self) -> Iterator[DebianRepository]:
def __iter__(self) -> Iterator[DebianRepository]: # pyright: ignore[reportIncompatibleMethodOverride]
"""Return iterator for RepositoryMapping.

Iterates over the DebianRepository values rather than the string names.
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ commands =
#pyright {[vars]lib_dir} {posargs} # FIXME: enable static analysis of all libs
#pyright {[vars]test_dir} {posargs} # FIXME: enable static analysis of tests
echo 'NOTE: Only manually specified libs are being static type checked currently'
pyright lib/charms/operator_libs_linux/v0/apt.py {posargs}
pyright lib/charms/operator_libs_linux/v2/snap.py {posargs}

[testenv:unit]
Expand Down