diff --git a/portablemc/download.py b/portablemc/download.py index 59943fc0..776b0c27 100644 --- a/portablemc/download.py +++ b/portablemc/download.py @@ -54,16 +54,18 @@ class _DownloadEntry: unsupported URL schemes. """ - __slots__ = "https", "host", "port", "entry" + __slots__ = "https", "host", "port", "path", "entry", "redirect" - def __init__(self, https: bool, host: str, port: Optional[int], entry: DownloadEntry) -> None: + def __init__(self, https: bool, host: str, port: Optional[int], path: str, entry: DownloadEntry, *, redirect: int = 0) -> None: self.https = https self.host = host self.port = port + self.path = path self.entry = entry + self.redirect = redirect @classmethod - def from_entry(cls, entry: DownloadEntry) -> "_DownloadEntry": + def from_entry(cls, entry: DownloadEntry, *, redirect: int = 0) -> "_DownloadEntry": # We only support HTTP/HTTPS url_parsed = urllib.parse.urlparse(entry.url) @@ -74,7 +76,9 @@ def from_entry(cls, entry: DownloadEntry) -> "_DownloadEntry": url_parsed.scheme == "https", url_parsed.netloc, url_parsed.port, - entry) + url_parsed.path, + entry, + redirect=redirect) class DownloadResult: @@ -103,6 +107,7 @@ class DownloadResultError(DownloadResult): """ CONNECTION = "connection" + TOO_MANY_REDIRECT = "too_many_redirect" NOT_FOUND = "not_found" INVALID_SIZE = "invalid_size" INVALID_SHA1 = "invalid_sha1" @@ -277,6 +282,7 @@ def _download_thread( # Maximum tries count or a single entry. max_try_count = 3 + max_redirect = 10 # For speed calculation. speed_update_interval = 0.25 @@ -325,7 +331,7 @@ def _download_thread( # This try-except block is around all potential try: - conn.request("GET", entry.url) + conn.request("GET", raw_entry.path) res = conn.getresponse() if res.status != 200: @@ -334,8 +340,13 @@ def _download_thread( # and allow further request. while res.readinto(buffer): pass - + if res.status == 301 or res.status == 302: + + if raw_entry.redirect >= max_redirect: + last_error = DownloadResultError.TOO_MANY_REDIRECT + continue + # If location header is absent, consider it not found. redirect_url = res.headers.get("location") if redirect_url is not None: @@ -346,7 +357,7 @@ def _download_thread( sha1=entry.sha1, name=entry.name) - entries_queue.put(_DownloadEntry.from_entry(redirect_entry)) + entries_queue.put(_DownloadEntry.from_entry(redirect_entry, redirect=raw_entry.redirect + 1)) break # Abort on redirect # Any other non-200 code is considered not found and we retry...