diff --git a/aws_pcluster_dask_gateway/__init__.py b/aws_pcluster_dask_gateway/__init__.py index 402c3a9..f74a120 100644 --- a/aws_pcluster_dask_gateway/__init__.py +++ b/aws_pcluster_dask_gateway/__init__.py @@ -1,10 +1,13 @@ """Top-level package for aws_pcluster_dask_gateway.""" __author__ = """Jillian Rowe""" -__email__ = 'jillian@dabbleofdevops.com' -__version__ = '0.1.0' +__email__ = "jillian@dabbleofdevops.com" +__version__ = "0.1.0" -from aws_pcluster_dask_gateway.aws_pcluster_dask_gateway import PClusterBackend +from aws_pcluster_dask_gateway.aws_pcluster_dask_gateway import ( + PClusterBackend, DaskGatewaySlurmConfig +) from . import _version -__version__ = _version.get_versions()['version'] + +__version__ = _version.get_versions()["version"] diff --git a/aws_pcluster_dask_gateway/_version.py b/aws_pcluster_dask_gateway/_version.py index 7c611d7..a099ce1 100644 --- a/aws_pcluster_dask_gateway/_version.py +++ b/aws_pcluster_dask_gateway/_version.py @@ -1,4 +1,3 @@ - # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -68,12 +67,14 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate @@ -100,10 +101,14 @@ def run_command( try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, + ) break except OSError as e: if e.errno == errno.ENOENT: @@ -141,15 +146,21 @@ def versions_from_parentdir( for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -212,7 +223,7 @@ def git_versions_from_keywords( # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -221,7 +232,7 @@ def git_versions_from_keywords( # between branches and tags. 
By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -229,32 +240,36 @@ def git_versions_from_keywords( for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') - if not re.match(r'\d', r): + if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") def git_pieces_from_vcs( - tag_prefix: str, - root: str, - verbose: bool, - runner: Callable = run_command + tag_prefix: str, root: str, verbose: bool, runner: Callable = run_command ) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. 
@@ -273,8 +288,7 @@ def git_pieces_from_vcs( env.pop("GIT_DIR", None) runner = functools.partial(runner, env=env) - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -282,10 +296,19 @@ def git_pieces_from_vcs( # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + f"{tag_prefix}[[:digit:]]*", + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -300,8 +323,7 @@ def git_pieces_from_vcs( pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") @@ -341,17 +363,16 @@ def git_pieces_from_vcs( dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? 
- pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -360,10 +381,12 @@ def git_pieces_from_vcs( if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -412,8 +435,7 @@ def render_pep440(pieces: Dict[str, Any]) -> str: rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -442,8 +464,7 @@ def render_pep440_branch(pieces: Dict[str, Any]) -> str: rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -604,11 +625,13 @@ def render_git_describe_long(pieces: Dict[str, Any]) -> str: def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -632,9 +655,13 @@ def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions() -> Dict[str, Any]: @@ -648,8 +675,7 @@ def get_versions() -> Dict[str, Any]: verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -658,13 +684,16 @@ def get_versions() -> Dict[str, Any]: # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. 
- for _ in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -678,6 +707,10 @@ def get_versions() -> Dict[str, Any]: except NotThisMethod: pass - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/aws_pcluster_dask_gateway/aws_pcluster_dask_gateway.py b/aws_pcluster_dask_gateway/aws_pcluster_dask_gateway.py index 415a7e4..a35c4f5 100644 --- a/aws_pcluster_dask_gateway/aws_pcluster_dask_gateway.py +++ b/aws_pcluster_dask_gateway/aws_pcluster_dask_gateway.py @@ -1,7 +1,7 @@ from __future__ import annotations import shutil -from aws_pcluster_helpers.models.sinfo import (SInfoTable, SinfoRow) +from aws_pcluster_helpers.models.sinfo import SInfoTable, SinfoRow import functools import os @@ -27,7 +27,7 @@ InstanceTypesMappings, ) -from aws_pcluster_helpers.models import sinfo +from aws_pcluster_helpers.models.sinfo import SInfoTable, SinfoRow from aws_pcluster_helpers import ( PClusterConfig, InstanceTypesData, @@ -41,13 +41,17 @@ from traitlets import Unicode, default from traitlets import Bool, Float, Integer, List, Unicode, default, validate -from dask_gateway_server.backends.jobqueue.slurm import SlurmClusterConfig, SlurmBackend, slurm_format_memory -from dask_gateway_server.options import Options, Select from dask_gateway_server.options import Options, Select, String from dask_gateway_server.traitlets import Command, Type from pydantic import BaseModel, computed_field -logger = setup_logger('dask-gateway') +from aws_pcluster_dask_gateway.dask_gateway_extensions.backends.jobqueue.slurm import ( + SlurmClusterConfig, + SlurmBackend, + slurm_format_memory, +) + +logger = setup_logger("dask-gateway") """ Docs @@ -58,7 +62,7 @@ """ -class DaskGatewaySlurmConfig(sinfo.SInfoTable): +class DaskGatewaySlurmConfig(BaseModel): """ Configure the Dask Gateway Cluster Each partition/instance type gets its own profile @@ -82,12 +86,14 @@ def options_handler(options): "worker_memory": int(options.worker_memory * 2 ** 30) } """ + pcluster_config_files: PClusterConfigFiles = PClusterConfigFiles() @computed_field @property def profiles(self) -> Dict[str, Dict[str, Any]]: profiles = {} - for sinfo_row in self.rows: + sinfo_table = SInfoTable(pcluster_config_files=self.pcluster_config_files) + for sinfo_row in sinfo_table.rows: label = f"P: {sinfo_row.queue}, I: {sinfo_row.ec2_instance_type}, CPU: {sinfo_row.vcpu}, Mem: {sinfo_row.mem}" memory = sinfo_row.mem / sinfo_row.vcpu # Using all of the available memory is very error prone @@ -110,21 +116,32 @@ def __post_init__(self): class PClusterConfig(SlurmClusterConfig): """Dask cluster configuration options when running on SLURM""" + partition = Unicode("", help="The partition to submit jobs to.", config=True) qos = Unicode("", help="QOS string associated with each job.", config=True) account = Unicode("", help="Account string associated with each job.", config=True) constraint = Unicode("", help="The job 
instance type constraint.", config=True) - wall_time = Unicode("", help="The walltime. The cluster will be brought down after the wall time is complete.", - config=True) + wall_time = Unicode( + "", + help="The walltime. The cluster will be brought down after the wall time is complete.", + config=True, + ) scheduler_cmd = Command( - shutil.which("dask-scheduler"), help="Shell command to start a dask scheduler.", config=True + shutil.which("dask-scheduler"), + help="Shell command to start a dask scheduler.", + config=True, ) worker_cmd = Command( - shutil.which("dask-worker"), help="Shell command to start a dask worker.", config=True + shutil.which("dask-worker"), + help="Shell command to start a dask worker.", + config=True, ) -def get_cluster_options(default_profile=None): +def get_cluster_options( + default_profile=None, + pcluster_config_files: Optional[PClusterConfigFiles] = None +): """ In your dask_gateway_config.py set the cluster options to cluster_options @@ -134,7 +151,10 @@ def get_cluster_options(default_profile=None): :param default_profile: :return: """ - dask_gateway_slurm_config = DaskGatewaySlurmConfig() + if pcluster_config_files: + dask_gateway_slurm_config = DaskGatewaySlurmConfig(pcluster_config_files=pcluster_config_files) + else: + dask_gateway_slurm_config = DaskGatewaySlurmConfig() profile_names = list(dask_gateway_slurm_config.profiles.keys()) if not default_profile: default_profile = profile_names[0] @@ -147,19 +167,22 @@ def get_cluster_options(default_profile=None): default=default_profile, label="Cluster Profile", ), - String( - "environment", - label="Conda Environment" - ), + String("environment", label="Conda Environment"), handler=lambda options: dask_gateway_slurm_config.profiles[options.profile], ) class PClusterBackend(SlurmBackend): - cluster_options = get_cluster_options() + + # make sure to keep this as @property so its deferred + @property + def cluster_options(self) -> Any: + return get_cluster_options() + dask_gateway_jobqueue_launcher = Unicode( - shutil.which('dask-gateway-jobqueue-launcher'), - help="The path to the dask-gateway-jobqueue-launcher executable", config=True + shutil.which("dask-gateway-jobqueue-launcher"), + help="The path to the dask-gateway-jobqueue-launcher executable", + config=True, ) cluster_start_timeout = Float( 3600, @@ -225,7 +248,7 @@ def get_submit_cmd_env_stdin(self, cluster, worker=None): cmd.append("--constraint=" + str(cluster.config.constraint)) if worker: - logger.info('Configuring dask-gateway worker') + logger.info("Configuring dask-gateway worker") cpus = cluster.config.worker_cores mem = slurm_format_memory(cluster.config.worker_memory) log_file = "dask-worker-%s.log" % worker.name @@ -238,7 +261,7 @@ def get_submit_cmd_env_stdin(self, cluster, worker=None): ) env = self.get_worker_env(cluster) else: - logger.info('Configuring dask-gateway scheduler') + logger.info("Configuring dask-gateway scheduler") cpus = cluster.config.scheduler_cores mem = slurm_format_memory(cluster.config.worker_memory) log_file = "dask-scheduler-%s.log" % cluster.name @@ -264,7 +287,7 @@ def get_submit_cmd_env_stdin(self, cluster, worker=None): ] ) - logger.info(f'Cmd: {cmd}') - logger.info(f'Env: {env}') - logger.info(f'Script: {script}') + logger.info(f"Cmd: {cmd}") + logger.info(f"Env: {env}") + logger.info(f"Script: {script}") return cmd, env, script diff --git a/aws_pcluster_dask_gateway/cli.py b/aws_pcluster_dask_gateway/cli.py index 028fa5c..60135b8 100644 --- a/aws_pcluster_dask_gateway/cli.py +++ 
b/aws_pcluster_dask_gateway/cli.py @@ -6,8 +6,10 @@ @click.command() def main(args=None): """Console script for aws_pcluster_dask_gateway.""" - click.echo("Replace this message by putting your code into " - "aws_pcluster_dask_gateway.cli.main") + click.echo( + "Replace this message by putting your code into " + "aws_pcluster_dask_gateway.cli.main" + ) click.echo("See click documentation at https://click.palletsprojects.com/") return 0 diff --git a/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/__init__.py b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/__init__.py new file mode 100644 index 0000000..a7da247 --- /dev/null +++ b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/__init__.py @@ -0,0 +1 @@ +from .base import Backend, ClusterConfig diff --git a/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/base.py b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/base.py new file mode 100644 index 0000000..8a9a4d3 --- /dev/null +++ b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/base.py @@ -0,0 +1,481 @@ +import asyncio + +import aiohttp +from traitlets import ( + Dict, + Float, + Instance, + Integer, + Unicode, + Union, + default, + observe, + validate, +) +from traitlets.config import Configurable, LoggingConfigurable + +from dask_gateway_server import models +from dask_gateway_server.options import Options +from dask_gateway_server.traitlets import Callable, Command, MemoryLimit, Type +from dask_gateway_server.utils import awaitable, format_bytes + +# from .. import models +# from ..options import Options +# from ..traitlets import Callable, Command, MemoryLimit, Type +# from ..utils import awaitable, format_bytes + + +__all__ = ("Backend", "ClusterConfig") + + +class PublicException(Exception): + """An exception that can be reported to the user""" + + pass + + +class Backend(LoggingConfigurable): + """Base class for defining dask-gateway backends. + + Subclasses should implement the following methods: + + - ``setup`` + - ``cleanup`` + - ``start_cluster`` + - ``stop_cluster`` + - ``on_cluster_heartbeat`` + """ + + cluster_options = Union( + [Callable(), Instance(Options, args=())], + help=""" + User options for configuring an individual cluster. + + Allows users to specify configuration overrides when creating a new + cluster. See the documentation for more information: + + :doc:`cluster-options`. + """, + config=True, + ) + + cluster_config_class = Type( + "dask_gateway_server.backends.base.ClusterConfig", + klass="dask_gateway_server.backends.base.ClusterConfig", + help="The cluster config class to use", + config=True, + ) + + scheduler_api_retries = Integer( + 3, + min=0, + help=""" + The number of attempts to make when contacting a scheduler api endpoint. + + If failures occur after the max number of retries, the dask cluster will + be marked as failed and will be cleaned up. + """, + ) + + api_url = Unicode( + help=""" + The address that internal components (e.g. dask clusters) + will use when contacting the gateway. 
+ """, + config=True, + ) + + # Forwarded from the main application + gateway_address = Unicode() + + @default("gateway_address") + def _gateway_address_default(self): + return self.parent.address + + async def get_cluster_options(self, user): + if callable(self.cluster_options): + return await awaitable(self.cluster_options(user)) + return self.cluster_options + + async def process_cluster_options(self, user, request): + try: + cluster_options = await self.get_cluster_options(user) + requested_options = cluster_options.parse_options(request) + overrides = cluster_options.get_configuration(requested_options, user) + config = self.cluster_config_class(parent=self, **overrides) + except asyncio.CancelledError: + raise + except Exception as exc: + raise PublicException(str(exc)) + return requested_options, config + + async def forward_message_to_scheduler(self, cluster, msg): + if cluster.status != models.ClusterStatus.RUNNING: + raise PublicException(f"cluster {cluster.name} is not running") + attempt = 1 + t = 0.1 + while True: + try: + await self.session.post( + cluster.api_address + "/api/comm", + json=msg, + headers={"Authorization": "token %s" % cluster.token}, + raise_for_status=True, + ) + return + except Exception: + if attempt < self.scheduler_api_retries: + self.log.warning( + f"Failed to message cluster {cluster.name} on attempt " + f"{attempt}, retrying in {t} s", + exc_info=True, + ) + await asyncio.sleep(t) + attempt += 1 + t = min(t * 2, 5) + else: + break + self.log.warning( + f"Failed to message cluster {cluster.name} on attempt " + f"{attempt}, marking cluster as failed" + ) + await self.stop_cluster(cluster.name, failed=True) + raise PublicException(f"cluster {cluster.name} is not running") + + async def setup(self, app): + """Called when the server is starting up. + + Do any initialization here. + + Parameters + ---------- + app : aiohttp.web.Application + The aiohttp application. Can be used to add additional routes if + needed. + """ + self.session = aiohttp.ClientSession() + + async def cleanup(self): + """Called when the server is shutting down. + + Do any cleanup tasks in this method""" + await self.session.close() + + async def list_clusters(self, username=None, statuses=None): + """List known clusters. + + Parameters + ---------- + username : str, optional + A user name to filter on. If not provided, defaults to + all users. + statuses : list, optional + A list of statuses to filter on. If not provided, defaults to all + running and pending clusters. + + Returns + ------- + clusters : List[Cluster] + """ + raise NotImplementedError + + async def get_cluster(self, cluster_name, wait=False): + """Get information about a cluster. + + Parameters + ---------- + cluster_name : str + The cluster name. + wait : bool, optional + If True, wait until the cluster has started or failed before + returning. If waiting is not possible (or waiting for a long period + of time would be expensive) it is valid to return early with a + Cluster object in a state prior to RUNNING (the client will retry + in this case). Default is False. + + Returns + ------- + cluster : Cluster + """ + raise NotImplementedError + + async def start_cluster(self, user, cluster_options): + """Submit a new cluster. + + Parameters + ---------- + user : User + The user making the request. + cluster_options : dict + Any additional options provided by the user. 
+ + Returns + ------- + cluster_name : str + """ + raise NotImplementedError + + async def stop_cluster(self, cluster_name, failed=False): + """Stop a cluster. + + No-op if the cluster is already stopped. + + Parameters + ---------- + cluster_name : str + The cluster name. + failed : bool, optional + If True, the cluster should be marked as FAILED after stopping. If + False (default) it should be marked as STOPPED. + """ + raise NotImplementedError + + async def on_cluster_heartbeat(self, cluster_name, msg): + """Handle a cluster heartbeat. + + Parameters + ---------- + cluster_name : str + The cluster name. + msg : dict + The heartbeat message. + """ + raise NotImplementedError + + +class ClusterConfig(Configurable): + """Base class for holding individual Dask cluster configurations""" + + scheduler_cmd = Command( + "dask-scheduler", help="Shell command to start a dask scheduler.", config=True + ) + + worker_cmd = Command( + "dask-worker", help="Shell command to start a dask worker.", config=True + ) + + environment = Dict( + help=""" + Environment variables to set for both the worker and scheduler processes. + """, + config=True, + ) + + worker_memory = MemoryLimit( + "2 G", + help=""" + Number of bytes available for a dask worker. Allows the following + suffixes: + + - K -> Kibibytes + - M -> Mebibytes + - G -> Gibibytes + - T -> Tebibytes + """, + config=True, + ) + + worker_cores = Integer( + 1, + min=1, + help=""" + Number of cpu-cores available for a dask worker. + """, + config=True, + ) + + # Number of threads per worker. Defaults to the number of cores + worker_threads = Integer( + help=""" + Number of threads available for a dask worker. + + Defaults to ``worker_cores``. + """, + min=1, + config=True, + allow_none=True, + ) + + @default("worker_threads") + def _default_worker_threads(self): + return max(int(self.worker_cores), 1) + + @validate("worker_threads") + def _validate_worker_threads(self, proposal): + if not proposal.value: + return self._default_worker_threads() + return proposal.value + + scheduler_memory = MemoryLimit( + "2 G", + help=""" + Number of bytes available for a dask scheduler. Allows the following + suffixes: + + - K -> Kibibytes + - M -> Mebibytes + - G -> Gibibytes + - T -> Tebibytes + """, + config=True, + ) + + scheduler_cores = Integer( + 1, + min=1, + help=""" + Number of cpu-cores available for a dask scheduler. + """, + config=True, + ) + + adaptive_period = Float( + 3, + min=0, + help=""" + Time (in seconds) between adaptive scaling checks. + + A smaller period will decrease scale up/down latency when responding to + cluster load changes, but may also result in higher load on the gateway + server. + """, + config=True, + ) + + idle_timeout = Float( + 0, + min=0, + help=""" + Time (in seconds) before an idle cluster is automatically shutdown. + + Set to 0 (default) for no idle timeout. + """, + config=True, + ) + + cluster_max_memory = MemoryLimit( + None, + help=""" + The maximum amount of memory (in bytes) available to this cluster. + Allows the following suffixes: + + - K -> Kibibytes + - M -> Mebibytes + - G -> Gibibytes + - T -> Tebibytes + + Set to ``None`` for no memory limit (default). + """, + min=0, + allow_none=True, + config=True, + ) + + cluster_max_cores = Float( + None, + help=""" + The maximum number of cores available to this cluster. + + Set to ``None`` for no cores limit (default). 
+ """, + min=0.0, + allow_none=True, + config=True, + ) + + cluster_max_workers = Integer( + help=""" + The maximum number of workers available to this cluster. + + Note that this will be combined with ``cluster_max_cores`` and + ``cluster_max_memory`` at runtime to determine the actual maximum + number of workers available to this cluster. + """, + allow_none=True, + min=0, + config=True, + ) + + def _check_scheduler_memory(self, scheduler_memory, cluster_max_memory): + if cluster_max_memory is not None and scheduler_memory > cluster_max_memory: + memory = format_bytes(scheduler_memory) + limit = format_bytes(cluster_max_memory) + raise ValueError( + f"Scheduler memory request of {memory} exceeds cluster memory " + f"limit of {limit}" + ) + + def _check_scheduler_cores(self, scheduler_cores, cluster_max_cores): + if cluster_max_cores is not None and scheduler_cores > cluster_max_cores: + raise ValueError( + f"Scheduler cores request of {scheduler_cores} exceeds cluster " + f"cores limit of {cluster_max_cores}" + ) + + def _worker_limit_from_resources(self): + inf = max_workers = float("inf") + if self.cluster_max_memory is not None: + max_workers = min( + (self.cluster_max_memory - self.scheduler_memory) // self.worker_memory, + max_workers, + ) + if self.cluster_max_cores is not None: + max_workers = min( + (self.cluster_max_cores - self.scheduler_cores) // self.worker_cores, + max_workers, + ) + + if max_workers == inf: + return None + return max(0, int(max_workers)) + + @validate("scheduler_memory") + def _validate_scheduler_memory(self, proposal): + self._check_scheduler_memory(proposal.value, self.cluster_max_memory) + return proposal.value + + @validate("scheduler_cores") + def _validate_scheduler_cores(self, proposal): + self._check_scheduler_cores(proposal.value, self.cluster_max_cores) + return proposal.value + + @validate("cluster_max_memory") + def _validate_cluster_max_memory(self, proposal): + self._check_scheduler_memory(self.scheduler_memory, proposal.value) + return proposal.value + + @validate("cluster_max_cores") + def _validate_cluster_max_cores(self, proposal): + self._check_scheduler_cores(self.scheduler_cores, proposal.value) + return proposal.value + + @validate("cluster_max_workers") + def _validate_cluster_max_workers(self, proposal): + lim = self._worker_limit_from_resources() + if lim is None: + return proposal.value + if proposal.value is None: + return lim + return min(proposal.value, lim) + + @observe("cluster_max_workers") + def _observe_cluster_max_workers(self, change): + # This shouldn't be needed, but traitlet validators don't run + # if a value is `None` and `allow_none` is true, so we need to + # add an observer to handle the event of an *explicit* `None` + # set for `cluster_max_workers` + if change.new is None: + lim = self._worker_limit_from_resources() + if lim is not None: + self.cluster_max_workers = lim + + @default("cluster_max_workers") + def _default_cluster_max_workers(self): + return self._worker_limit_from_resources() + + def to_dict(self): + return { + k: getattr(self, k) + for k in self.trait_names() + if k not in {"parent", "config"} + } diff --git a/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/db_base.py b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/db_base.py new file mode 100644 index 0000000..d18f85c --- /dev/null +++ b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/db_base.py @@ -0,0 +1,1585 @@ +import asyncio +import base64 +import json +import os +import uuid +from collections import 
defaultdict +from itertools import chain, islice + +import sqlalchemy as sa +from async_timeout import timeout +from cryptography.fernet import Fernet, MultiFernet +from traitlets import Bool, Float, Integer, List, Unicode, default, validate + +from dask_gateway_server import models +from dask_gateway_server.proxy import Proxy +from dask_gateway_server.tls import new_keypair +from dask_gateway_server.utils import ( + Flag, + FrozenAttrDict, + TaskPool, + normalize_address, + timestamp, +) +from dask_gateway_server.workqueue import Backoff, WorkQueue, WorkQueueClosed +from .base import Backend + +__all__ = ("DBBackendBase", "Cluster", "Worker") + + +def _normalize_encrypt_key(key): + if isinstance(key, str): + key = key.encode("ascii") + + if len(key) == 44: + try: + key = base64.urlsafe_b64decode(key) + except ValueError: + pass + + if len(key) == 32: + return base64.urlsafe_b64encode(key) + + raise ValueError( + "All keys in `db_encrypt_keys`/`DASK_GATEWAY_ENCRYPT_KEYS` must be 32 " + "bytes, base64-encoded" + ) + + +def _is_in_memory_db(url): + return url in ("sqlite://", "sqlite:///:memory:") + + +class _IntEnum(sa.TypeDecorator): + impl = sa.Integer + + def __init__(self, enumclass, *args, **kwargs): + super().__init__(*args, **kwargs) + self._enumclass = enumclass + + def process_bind_param(self, value, dialect): + return value.value + + def process_result_value(self, value, dialect): + return self._enumclass(value) + + +class _JSON(sa.TypeDecorator): + "Represents an immutable structure as a json-encoded string." + + impl = sa.LargeBinary + + def process_bind_param(self, value, dialect): + if value is not None: + value = json.dumps(value).encode("utf-8") + return value + + def process_result_value(self, value, dialect): + if value is not None: + value = json.loads(value) + return value + + +class JobStatus(models.IntEnum): + CREATED = 1 + SUBMITTED = 2 + RUNNING = 3 + CLOSING = 4 + STOPPED = 5 + FAILED = 6 + + +class Cluster: + """Information on a cluster. + + Not all attributes on this object are publically accessible. When writing a + backend, you may access the following attributes: + + Attributes + ---------- + name : str + The cluster name. + username : str + The user associated with this cluster. + token : str + The API token associated with this cluster. Used to authenticate the + cluster with the gateway. + config : FrozenAttrDict + The serialized ``ClusterConfig`` associated with this cluster. + state : dict + Any cluster state, as yielded from ``do_start_cluster``. + scheduler_address : str + The scheduler address. The empty string if the cluster is not running. + dashboard_address : str + The dashboard address. The empty string if the cluster is not running, + or no dashboard is running on the cluster. + api_address : str + The cluster's api address. The empty string if the cluster is not running. + tls_cert : bytes + The TLS cert credentials associated with the cluster. + tls_key : bytes + The TLS key credentials associated with the cluster. 
+ """ + + def __init__( + self, + id=None, + name=None, + username=None, + token=None, + options=None, + config=None, + status=None, + target=None, + count=0, + state=None, + scheduler_address="", + dashboard_address="", + api_address="", + tls_cert=b"", + tls_key=b"", + start_time=None, + stop_time=None, + ): + self.id = id + self.name = name + self.username = username + self.token = token + self.options = options + self.config = config + self.status = status + self.target = target + self.count = count + self.state = state + self.scheduler_address = scheduler_address + self.dashboard_address = dashboard_address + self.api_address = api_address + self.tls_cert = tls_cert + self.tls_key = tls_key + self.start_time = start_time + self.stop_time = stop_time + + if self.status == JobStatus.RUNNING: + self.last_heartbeat = timestamp() + else: + self.last_heartbeat = None + self.worker_start_failure_count = 0 + self.added_to_proxies = False + self.workers = {} + + self.ready = Flag() + if self.status >= JobStatus.RUNNING: + self.ready.set() + self.shutdown = Flag() + if self.status >= JobStatus.STOPPED: + self.shutdown.set() + + _status_map = { + (JobStatus.CREATED, JobStatus.RUNNING): models.ClusterStatus.PENDING, + (JobStatus.CREATED, JobStatus.CLOSING): models.ClusterStatus.STOPPING, + (JobStatus.CREATED, JobStatus.STOPPED): models.ClusterStatus.STOPPING, + (JobStatus.CREATED, JobStatus.FAILED): models.ClusterStatus.STOPPING, + (JobStatus.SUBMITTED, JobStatus.RUNNING): models.ClusterStatus.PENDING, + (JobStatus.SUBMITTED, JobStatus.CLOSING): models.ClusterStatus.STOPPING, + (JobStatus.SUBMITTED, JobStatus.STOPPED): models.ClusterStatus.STOPPING, + (JobStatus.SUBMITTED, JobStatus.FAILED): models.ClusterStatus.STOPPING, + (JobStatus.RUNNING, JobStatus.RUNNING): models.ClusterStatus.RUNNING, + (JobStatus.RUNNING, JobStatus.CLOSING): models.ClusterStatus.STOPPING, + (JobStatus.RUNNING, JobStatus.STOPPED): models.ClusterStatus.STOPPING, + (JobStatus.RUNNING, JobStatus.FAILED): models.ClusterStatus.STOPPING, + (JobStatus.CLOSING, JobStatus.STOPPED): models.ClusterStatus.STOPPING, + (JobStatus.CLOSING, JobStatus.FAILED): models.ClusterStatus.STOPPING, + (JobStatus.STOPPED, JobStatus.STOPPED): models.ClusterStatus.STOPPED, + (JobStatus.FAILED, JobStatus.FAILED): models.ClusterStatus.FAILED, + } + + def active_workers(self): + return [w for w in self.workers.values() if w.is_active()] + + def is_active(self): + return self.target < JobStatus.STOPPED + + def all_workers_at_least(self, status): + return all(w.status >= status for w in self.workers.values()) + + @property + def model_status(self): + return self._status_map[self.status, self.target] + + def to_model(self): + return models.Cluster( + name=self.name, + username=self.username, + token=self.token, + options=self.options, + config=self.config, + status=self.model_status, + scheduler_address=self.scheduler_address, + dashboard_address=self.dashboard_address, + api_address=self.api_address, + tls_cert=self.tls_cert, + tls_key=self.tls_key, + start_time=self.start_time, + stop_time=self.stop_time, + ) + + +class Worker: + """Information on a worker. + + Not all attributes on this object are publicly accessible. When writing a + backend, you may access the following attributes: + + Attributes + ---------- + name : str + The worker name. + cluster : Cluster + The cluster associated with this worker. + state : dict + Any worker state, as yielded from ``do_start_worker``. 
+ """ + + def __init__( + self, + id=None, + name=None, + cluster=None, + status=None, + target=None, + state=None, + start_time=None, + stop_time=None, + close_expected=False, + ): + self.id = id + self.name = name + self.cluster = cluster + self.status = status + self.target = target + self.state = state + self.start_time = start_time + self.stop_time = stop_time + self.close_expected = close_expected + + def is_active(self): + return self.target < JobStatus.STOPPED + + +metadata = sa.MetaData() + +clusters = sa.Table( + "clusters", + metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.Unicode(255), nullable=False, unique=True), + sa.Column("username", sa.Unicode(255), nullable=False), + sa.Column("status", _IntEnum(JobStatus), nullable=False), + sa.Column("target", _IntEnum(JobStatus), nullable=False), + sa.Column("count", sa.Integer, nullable=False), + sa.Column("options", _JSON, nullable=False), + sa.Column("config", _JSON, nullable=False), + sa.Column("state", _JSON, nullable=False), + sa.Column("token", sa.BINARY(140), nullable=False, unique=True), + sa.Column("scheduler_address", sa.Unicode(255), nullable=False), + sa.Column("dashboard_address", sa.Unicode(255), nullable=False), + sa.Column("api_address", sa.Unicode(255), nullable=False), + sa.Column("tls_credentials", sa.LargeBinary, nullable=False), + sa.Column("start_time", sa.Integer, nullable=False), + sa.Column("stop_time", sa.Integer, nullable=True), +) + +workers = sa.Table( + "workers", + metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.Unicode(255), nullable=False), + sa.Column( + "cluster_id", sa.ForeignKey("clusters.id", ondelete="CASCADE"), nullable=False + ), + sa.Column("status", _IntEnum(JobStatus), nullable=False), + sa.Column("target", _IntEnum(JobStatus), nullable=False), + sa.Column("state", _JSON, nullable=False), + sa.Column("start_time", sa.Integer, nullable=False), + sa.Column("stop_time", sa.Integer, nullable=True), + sa.Column("close_expected", sa.Integer, nullable=False), +) + + +class DataManager: + """Holds the internal state for a single Dask Gateway. + + Keeps the memory representation in-sync with the database. 
+ """ + + def __init__(self, url="sqlite:///:memory:", encrypt_keys=(), **kwargs): + if url.startswith("sqlite"): + kwargs["connect_args"] = {"check_same_thread": False} + + if _is_in_memory_db(url): + kwargs["poolclass"] = sa.pool.StaticPool + self.fernet = None + else: + self.fernet = MultiFernet([Fernet(key) for key in encrypt_keys]) + + engine = sa.create_engine(url, **kwargs) + if url.startswith("sqlite"): + # Register PRAGMA foreigh_keys=on for sqlite + @sa.event.listens_for(engine, "connect") + def connect(dbapi_con, con_record): + cursor = dbapi_con.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + metadata.create_all(engine) + + self.db = engine + + self.username_to_clusters = defaultdict(dict) + self.name_to_cluster = {} + self.id_to_cluster = {} + + # Load all existing clusters into memory + with self.db.begin() as connection: + for c in connection.execute(clusters.select()): + tls_cert, tls_key = self.decode_tls_credentials(c.tls_credentials) + token = self.decode_token(c.token) + cluster = Cluster( + id=c.id, + name=c.name, + username=c.username, + token=token, + options=c.options, + config=FrozenAttrDict(c.config), + status=c.status, + target=c.target, + count=c.count, + state=c.state, + scheduler_address=c.scheduler_address, + dashboard_address=c.dashboard_address, + api_address=c.api_address, + tls_cert=tls_cert, + tls_key=tls_key, + start_time=c.start_time, + stop_time=c.stop_time, + ) + self.username_to_clusters[cluster.username][cluster.name] = cluster + self.id_to_cluster[cluster.id] = cluster + self.name_to_cluster[cluster.name] = cluster + + # Next load all existing workers into memory + for w in connection.execute(workers.select()): + cluster = self.id_to_cluster[w.cluster_id] + worker = Worker( + id=w.id, + name=w.name, + status=w.status, + target=w.target, + cluster=cluster, + state=w.state, + start_time=w.start_time, + stop_time=w.stop_time, + close_expected=w.close_expected, + ) + cluster.workers[worker.name] = worker + + def cleanup_expired(self, max_age_in_seconds): + cutoff = timestamp() - max_age_in_seconds * 1000 + with self.db.begin() as conn: + to_delete = conn.execute( + sa.select(clusters.c.id).where(clusters.c.stop_time < cutoff) + ).fetchall() + + if to_delete: + to_delete = [i for i, in to_delete] + + conn.execute( + clusters.delete().where(clusters.c.id == sa.bindparam("id")), + [{"id": i} for i in to_delete], + ) + + for i in to_delete: + cluster = self.id_to_cluster.pop(i) + self.name_to_cluster.pop(cluster.name, None) + user_clusters = self.username_to_clusters[cluster.username] + user_clusters.pop(cluster.name) + if not user_clusters: + self.username_to_clusters.pop(cluster.username) + + return len(to_delete) + + def encrypt(self, b): + """Encrypt bytes ``b``. If encryption is disabled this is a no-op""" + return b if self.fernet is None else self.fernet.encrypt(b) + + def decrypt(self, b): + """Decrypt bytes ``b``. 
If encryption is disabled this is a no-op""" + return b if self.fernet is None else self.fernet.decrypt(b) + + def encode_tls_credentials(self, tls_cert, tls_key): + return self.encrypt(b";".join((tls_cert, tls_key))) + + def decode_tls_credentials(self, data): + return self.decrypt(data).split(b";") + + def encode_token(self, token): + return self.encrypt(token.encode("utf8")) + + def decode_token(self, data): + return self.decrypt(data).decode() + + def get_cluster(self, cluster_name): + return self.name_to_cluster.get(cluster_name) + + def list_clusters(self, username=None, statuses=None): + if statuses is None: + select = lambda x: x.is_active() + else: + statuses = set(statuses) + select = lambda x: x.model_status in statuses + if username is None: + return [ + cluster for cluster in self.name_to_cluster.values() if select(cluster) + ] + else: + clusters = self.username_to_clusters.get(username) + if clusters is None: + return [] + return [cluster for cluster in clusters.values() if select(cluster)] + + def active_clusters(self): + for cluster in self.name_to_cluster.values(): + if cluster.is_active(): + yield cluster + + def create_cluster(self, username, options, config): + """Create a new cluster for a user""" + cluster_name = uuid.uuid4().hex + token = uuid.uuid4().hex + tls_cert, tls_key = new_keypair(cluster_name) + # Encode the tls credentials for storing in the database + tls_credentials = self.encode_tls_credentials(tls_cert, tls_key) + enc_token = self.encode_token(token) + + common = { + "name": cluster_name, + "username": username, + "options": options, + "status": JobStatus.CREATED, + "target": JobStatus.RUNNING, + "count": 0, + "state": {}, + "scheduler_address": "", + "dashboard_address": "", + "api_address": "", + "start_time": timestamp(), + } + + with self.db.begin() as conn: + res = conn.execute( + clusters.insert().values( + tls_credentials=tls_credentials, + token=enc_token, + config=config, + **common, + ) + ) + cluster = Cluster( + id=res.inserted_primary_key[0], + token=token, + tls_cert=tls_cert, + tls_key=tls_key, + config=FrozenAttrDict(config), + **common, + ) + self.id_to_cluster[cluster.id] = cluster + self.name_to_cluster[cluster_name] = cluster + self.username_to_clusters[username][cluster_name] = cluster + + return cluster + + def create_worker(self, cluster): + """Create a new worker for a cluster""" + worker_name = uuid.uuid4().hex + + common = { + "name": worker_name, + "status": JobStatus.CREATED, + "target": JobStatus.RUNNING, + "state": {}, + "start_time": timestamp(), + "close_expected": False, + } + + with self.db.begin() as conn: + res = conn.execute(workers.insert().values(cluster_id=cluster.id, **common)) + worker = Worker(id=res.inserted_primary_key[0], cluster=cluster, **common) + cluster.workers[worker.name] = worker + + return worker + + def update_cluster(self, cluster, **kwargs): + """Update a cluster's state""" + with self.db.begin() as conn: + conn.execute( + clusters.update().where(clusters.c.id == cluster.id).values(**kwargs) + ) + for k, v in kwargs.items(): + setattr(cluster, k, v) + + def update_clusters(self, updates): + """Update multiple clusters' states""" + if not updates: + return + with self.db.begin() as conn: + conn.execute( + clusters.update().where(clusters.c.id == sa.bindparam("_id")), + [{"_id": c.id, **u} for c, u in updates], + ) + for c, u in updates: + for k, v in u.items(): + setattr(c, k, v) + + def update_worker(self, worker, **kwargs): + """Update a worker's state""" + with self.db.begin() as conn: + 
conn.execute( + workers.update().where(workers.c.id == worker.id).values(**kwargs) + ) + for k, v in kwargs.items(): + setattr(worker, k, v) + + def update_workers(self, updates): + """Update multiple workers' states""" + if not updates: + return + with self.db.begin() as conn: + conn.execute( + workers.update().where(workers.c.id == sa.bindparam("_id")), + [{"_id": w.id, **u} for w, u in updates], + ) + for w, u in updates: + for k, v in u.items(): + setattr(w, k, v) + + +class DBBackendBase(Backend): + """A base class for defining backends that rely on a database for managing state. + + Subclasses should define the following methods: + + - ``do_setup`` + - ``do_cleanup`` + - ``do_start_cluster`` + - ``do_stop_cluster`` + - ``do_check_clusters`` + - ``do_start_worker`` + - ``do_stop_worker`` + - ``do_check_workers`` + """ + + db_url = Unicode( + "sqlite:///:memory:", + help=""" + The URL for the database. Default is in-memory only. + + If not in-memory, ``db_encrypt_keys`` must also be set. + """, + config=True, + ) + + db_encrypt_keys = List( + help=""" + A list of keys to use to encrypt private data in the database. Can also + be set by the environment variable ``DASK_GATEWAY_ENCRYPT_KEYS``, where + the value is a ``;`` delimited string of encryption keys. + + Each key should be a base64-encoded 32 byte value, and should be + cryptographically random. Lacking other options, openssl can be used to + generate a single key via: + + .. code-block:: shell + + $ openssl rand -base64 32 + + A single key is valid, multiple keys can be used to support key rotation. + """, + config=True, + ) + + @default("db_encrypt_keys") + def _db_encrypt_keys_default(self): + keys = [ + k.strip() + for k in os.environb.get(b"DASK_GATEWAY_ENCRYPT_KEYS", b"").split(b";") + if k.strip() + ] + return self._db_encrypt_keys_validate({"value": keys}) + + @validate("db_encrypt_keys") + def _db_encrypt_keys_validate(self, proposal): + if not proposal["value"] and not _is_in_memory_db(self.db_url): + raise ValueError( + "Must configure `db_encrypt_keys`/`DASK_GATEWAY_ENCRYPT_KEYS` " + "when not using an in-memory database" + ) + return [_normalize_encrypt_key(k) for k in proposal["value"]] + + db_debug = Bool( + False, help="If True, all database operations will be logged", config=True + ) + + db_cleanup_period = Float( + 600, + help=""" + Time (in seconds) between database cleanup tasks. + + This sets how frequently old records are removed from the database. + This shouldn't be too small (to keep the overhead low), but should be + smaller than ``db_record_max_age`` (probably by an order of magnitude). + """, + config=True, + ) + + db_cluster_max_age = Float( + 3600 * 24, + help=""" + Max time (in seconds) to keep around records of completed clusters. + + Every ``db_cleanup_period``, completed clusters older than + ``db_cluster_max_age`` are removed from the database. + """, + config=True, + ) + + stop_clusters_on_shutdown = Bool( + True, + help=""" + Whether to stop active clusters on gateway shutdown. + + If true, all active clusters will be stopped before shutting down the + gateway. Set to False to leave active clusters running. 
+ """, + config=True, + ) + + @validate("stop_clusters_on_shutdown") + def _stop_clusters_on_shutdown_validate(self, proposal): + if not proposal.value and _is_in_memory_db(self.db_url): + raise ValueError( + "When using an in-memory database, `stop_clusters_on_shutdown` " + "must be True" + ) + return proposal.value + + cluster_status_period = Float( + 30, + help=""" + Time (in seconds) between cluster status checks. + + A smaller period will detect failed clusters sooner, but will use more + resources. A larger period will provide slower feedback in the presence + of failures. + """, + config=True, + ) + + worker_status_period = Float( + 30, + help=""" + Time (in seconds) between worker status checks. + + A smaller period will detect failed workers sooner, but will use more + resources. A larger period will provide slower feedback in the presence + of failures. + """, + config=True, + ) + + cluster_heartbeat_period = Integer( + 15, + help=""" + Time (in seconds) between cluster heartbeats to the gateway. + + A smaller period will detect failed workers sooner, but will use more + resources. A larger period will provide slower feedback in the presence + of failures. + """, + config=True, + ) + + cluster_heartbeat_timeout = Float( + help=""" + Timeout (in seconds) before killing a dask cluster that's failed to heartbeat. + + This should be greater than ``cluster_heartbeat_period``. Defaults to + ``2 * cluster_heartbeat_period``. + """, + config=True, + ) + + @default("cluster_heartbeat_timeout") + def _default_cluster_heartbeat_timeout(self): + return self.cluster_heartbeat_period * 2 + + cluster_start_timeout = Float( + 60, + help=""" + Timeout (in seconds) before giving up on a starting dask cluster. + """, + config=True, + ) + + worker_start_timeout = Float( + 60, + help=""" + Timeout (in seconds) before giving up on a starting dask worker. + """, + config=True, + ) + + check_timeouts_period = Float( + help=""" + Time (in seconds) between timeout checks. + + This shouldn't be too small (to keep the overhead low), but should be + smaller than ``cluster_heartbeat_timeout``, ``cluster_start_timeout``, + and ``worker_start_timeout``. + """, + config=True, + ) + + @default("check_timeouts_period") + def _default_check_timeouts_period(self): + min_timeout = min( + self.cluster_heartbeat_timeout, + self.cluster_start_timeout, + self.worker_start_timeout, + ) + return min(20, min_timeout / 2) + + worker_start_failure_limit = Integer( + 3, + help=""" + A limit on the number of failed attempts to start a worker before the + cluster is marked as failed. + + Every worker that fails to start (timeouts exempt) increments a + counter. The counter is reset if a worker successfully starts. If the + counter ever exceeds this limit, the cluster is marked as failed and is + shutdown. + """, + config=True, + ) + + parallelism = Integer( + 20, + help=""" + Number of handlers to use for starting/stopping clusters. + """, + config=True, + ) + + backoff_base_delay = Float( + 0.1, + help=""" + Base delay (in seconds) for backoff when retrying after failures. + + If an operation fails, it is retried after a backoff computed as: + + ``` + min(backoff_max_delay, backoff_base_delay * 2 ** num_failures) + ``` + """, + config=True, + ) + + backoff_max_delay = Float( + 300, + help=""" + Max delay (in seconds) for backoff policy when retrying after failures. + """, + config=True, + ) + + api_url = Unicode( + help=""" + The address that internal components (e.g. dask clusters) + will use when contacting the gateway. 
+ + Defaults to `{proxy_address}/{prefix}/api`, set manually if a different + address should be used. + """, + config=True, + ) + + @default("api_url") + def _api_url_default(self): + proxy = self.proxy + scheme = "https" if proxy.tls_cert else "http" + address = normalize_address(proxy.address, resolve_host=True) + return f"{scheme}://{address}{proxy.prefix}/api" + + async def setup(self, app): + await super().setup(app) + + # Setup reconcilation queues + self.queue = WorkQueue( + backoff=Backoff( + base_delay=self.backoff_base_delay, max_delay=self.backoff_max_delay + ) + ) + self.reconcilers = [ + asyncio.ensure_future(self.reconciler_loop()) + for _ in range(self.parallelism) + ] + + # Start the proxy + self.proxy = Proxy(parent=self, log=self.log) + await self.proxy.setup(app) + + # Load the database + self.db = DataManager( + url=self.db_url, echo=self.db_debug, encrypt_keys=self.db_encrypt_keys + ) + + # Start background tasks + self.task_pool = TaskPool() + self.task_pool.spawn(self.check_timeouts_loop()) + self.task_pool.spawn(self.check_clusters_loop()) + self.task_pool.spawn(self.check_workers_loop()) + self.task_pool.spawn(self.cleanup_db_loop()) + + # Load all active clusters/workers into reconcilation queues + for cluster in self.db.name_to_cluster.values(): + if cluster.status < JobStatus.STOPPED: + self.queue.put(cluster) + for worker in cluster.workers.values(): + if worker.status < JobStatus.STOPPED: + self.queue.put(worker) + + # Further backend-specific setup + await self.do_setup() + + self.log.info( + "Backend started, clusters will contact api server at %s", self.api_url + ) + + async def cleanup(self): + if hasattr(self, "task_pool"): + # Stop background tasks + await self.task_pool.close() + + if hasattr(self, "db"): + if self.stop_clusters_on_shutdown: + # Request all active clusters be stopped + active = list(self.db.active_clusters()) + if active: + self.log.info("Stopping %d active clusters...", len(active)) + self.db.update_clusters( + [(c, {"target": JobStatus.FAILED}) for c in active] + ) + for c in active: + self.queue.put(c) + + # Wait until all clusters are shutdown + pending_shutdown = [ + c + for c in self.db.name_to_cluster.values() + if c.status < JobStatus.STOPPED + ] + if pending_shutdown: + await asyncio.wait( + [asyncio.ensure_future(c.shutdown) for c in pending_shutdown] + ) + + # Stop reconcilation queues + if hasattr(self, "reconcilers"): + self.queue.close() + await asyncio.gather(*self.reconcilers, return_exceptions=True) + + await self.do_cleanup() + + if hasattr(self, "proxy"): + await self.proxy.cleanup() + + await super().cleanup() + + async def list_clusters(self, username=None, statuses=None): + clusters = self.db.list_clusters(username=username, statuses=statuses) + return [c.to_model() for c in clusters] + + async def get_cluster(self, cluster_name, wait=False): + cluster = self.db.get_cluster(cluster_name) + if cluster is None: + return None + if wait: + try: + await asyncio.wait_for(cluster.ready, 20) + except asyncio.TimeoutError: + pass + return cluster.to_model() + + async def start_cluster(self, user, cluster_options): + options, config = await self.process_cluster_options(user, cluster_options) + cluster = self.db.create_cluster(user.name, options, config.to_dict()) + self.log.info("Created cluster %s for user %s", cluster.name, user.name) + self.queue.put(cluster) + return cluster.name + + async def stop_cluster(self, cluster_name, failed=False): + cluster = self.db.get_cluster(cluster_name) + if cluster is None: + return + if 
cluster.target <= JobStatus.RUNNING: + self.log.info("Stopping cluster %s", cluster.name) + target = JobStatus.FAILED if failed else JobStatus.STOPPED + self.db.update_cluster(cluster, target=target) + self.queue.put(cluster) + + async def on_cluster_heartbeat(self, cluster_name, msg): + cluster = self.db.get_cluster(cluster_name) + if cluster is None or cluster.target > JobStatus.RUNNING: + return + + cluster.last_heartbeat = timestamp() + + if cluster.status == JobStatus.RUNNING: + cluster_update = {} + else: + cluster_update = { + "api_address": msg["api_address"], + "scheduler_address": msg["scheduler_address"], + "dashboard_address": msg["dashboard_address"], + } + + count = msg["count"] + active_workers = set(msg["active_workers"]) + closing_workers = set(msg["closing_workers"]) + closed_workers = set(msg["closed_workers"]) + + self.log.info( + "Cluster %s heartbeat [count: %d, n_active: %d, n_closing: %d, n_closed: %d]", + cluster_name, + count, + len(active_workers), + len(closing_workers), + len(closed_workers), + ) + + max_workers = cluster.config.get("cluster_max_workers") + if max_workers is not None and count > max_workers: + # This shouldn't happen under normal operation, but could if the + # user does something malicious (or there's a bug). + self.log.info( + "Cluster %s heartbeat requested %d workers, exceeding limit of %s.", + cluster_name, + count, + max_workers, + ) + count = max_workers + + if count != cluster.count: + cluster_update["count"] = count + + created_workers = [] + submitted_workers = [] + target_updates = [] + newly_running = [] + close_expected = [] + for worker in cluster.workers.values(): + if worker.status >= JobStatus.STOPPED: + continue + elif worker.name in closing_workers: + if worker.status < JobStatus.RUNNING: + newly_running.append(worker) + close_expected.append(worker) + elif worker.name in active_workers: + if worker.status < JobStatus.RUNNING: + newly_running.append(worker) + elif worker.name in closed_workers: + target = ( + JobStatus.STOPPED if worker.close_expected else JobStatus.FAILED + ) + target_updates.append((worker, {"target": target})) + else: + if worker.status == JobStatus.SUBMITTED: + submitted_workers.append(worker) + else: + assert worker.status == JobStatus.CREATED + created_workers.append(worker) + + n_pending = len(created_workers) + len(submitted_workers) + n_to_stop = len(active_workers) + n_pending - count + if n_to_stop > 0: + for w in islice(chain(created_workers, submitted_workers), n_to_stop): + target_updates.append((w, {"target": JobStatus.STOPPED})) + + if cluster_update: + self.db.update_cluster(cluster, **cluster_update) + self.queue.put(cluster) + + self.db.update_workers(target_updates) + for w, u in target_updates: + self.queue.put(w) + + if newly_running: + # At least one worker successfully started, reset failure count + cluster.worker_start_failure_count = 0 + self.db.update_workers( + [(w, {"status": JobStatus.RUNNING}) for w in newly_running] + ) + for w in newly_running: + self.log.info("Worker %s is running", w.name) + + self.db.update_workers([(w, {"close_expected": True}) for w in close_expected]) + + async def check_timeouts_loop(self): + while True: + await asyncio.sleep(self.check_timeouts_period) + try: + await self._check_timeouts() + except asyncio.CancelledError: + raise + except Exception as exc: + self.log.warning( + "Exception while checking for timed out clusters/workers", + exc_info=exc, + ) + + async def _check_timeouts(self): + self.log.debug("Checking for timed out clusters/workers") + 
now = timestamp() + cluster_heartbeat_cutoff = now - self.cluster_heartbeat_timeout * 1000 + cluster_start_cutoff = now - self.cluster_start_timeout * 1000 + worker_start_cutoff = now - self.worker_start_timeout * 1000 + cluster_updates = [] + worker_updates = [] + for cluster in self.db.active_clusters(): + if cluster.status == JobStatus.SUBMITTED: + # Check if submitted clusters have timed out + if cluster.start_time < cluster_start_cutoff: + self.log.info("Cluster %s startup timed out", cluster.name) + cluster_updates.append((cluster, {"target": JobStatus.FAILED})) + elif cluster.status == JobStatus.RUNNING: + # Check if running clusters have missed a heartbeat + if cluster.last_heartbeat < cluster_heartbeat_cutoff: + self.log.info("Cluster %s heartbeat timed out", cluster.name) + cluster_updates.append((cluster, {"target": JobStatus.FAILED})) + else: + for w in cluster.workers.values(): + # Check if submitted workers have timed out + if ( + w.status == JobStatus.SUBMITTED + and w.target == JobStatus.RUNNING + and w.start_time < worker_start_cutoff + ): + self.log.info("Worker %s startup timed out", w.name) + worker_updates.append((w, {"target": JobStatus.FAILED})) + self.db.update_clusters(cluster_updates) + for c, _ in cluster_updates: + self.queue.put(c) + self.db.update_workers(worker_updates) + for w, _ in worker_updates: + self.queue.put(w) + + async def check_clusters_loop(self): + while True: + await asyncio.sleep(self.cluster_status_period) + self.log.debug("Checking pending cluster statuses") + try: + clusters = [ + c + for c in self.db.active_clusters() + if c.status == JobStatus.SUBMITTED + ] + statuses = await self.do_check_clusters(clusters) + updates = [ + (c, {"target": JobStatus.FAILED}) + for c, ok in zip(clusters, statuses) + if not ok + ] + self.db.update_clusters(updates) + for c, _ in updates: + self.log.info("Cluster %s failed during startup", c.name) + self.queue.put(c) + except asyncio.CancelledError: + raise + except Exception as exc: + self.log.warning( + "Exception while checking cluster statuses", exc_info=exc + ) + + async def check_workers_loop(self): + while True: + await asyncio.sleep(self.worker_status_period) + self.log.debug("Checking pending worker statuses") + try: + clusters = ( + c + for c in self.db.active_clusters() + if c.status == JobStatus.RUNNING + ) + workers = [ + w + for c in clusters + for w in c.active_workers() + if w.status == JobStatus.SUBMITTED + ] + statuses = await self.do_check_workers(workers) + updates = [ + (w, {"target": JobStatus.FAILED}) + for w, ok in zip(workers, statuses) + if not ok + ] + self.db.update_workers(updates) + for w, _ in updates: + self.log.info("Worker %s failed during startup", w.name) + w.cluster.worker_start_failure_count += 1 + self.queue.put(w) + except asyncio.CancelledError: + raise + except Exception as exc: + self.log.warning( + "Exception while checking worker statuses", exc_info=exc + ) + + async def cleanup_db_loop(self): + while True: + try: + n = self.db.cleanup_expired(self.db_cluster_max_age) + except Exception as exc: + self.log.error( + "Error while cleaning expired database records", exc_info=exc + ) + else: + self.log.debug("Removed %d expired clusters from the database", n) + await asyncio.sleep(self.db_cleanup_period) + + async def reconciler_loop(self): + while True: + try: + obj = await self.queue.get() + except WorkQueueClosed: + return + + if isinstance(obj, Cluster): + method = self.reconcile_cluster + kind = "cluster" + else: + method = self.reconcile_worker + kind = "worker" + 
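# Reconciliation below is a small state machine: each queued cluster/worker is
# stepped from its current `status` toward its `target`, re-queued with a
# backoff delay if a step raises, and the backoff is reset once a step succeeds.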
+ self.log.debug( + "Reconciling %s %s, %s -> %s", + kind, + obj.name, + obj.status.name, + obj.target.name, + ) + + try: + await method(obj) + except Exception: + self.log.warning( + "Error while reconciling %s %s", kind, obj.name, exc_info=True + ) + self.queue.put_backoff(obj) + else: + self.queue.reset_backoff(obj) + finally: + self.queue.task_done(obj) + + async def reconcile_cluster(self, cluster): + if cluster.status >= JobStatus.STOPPED: + return + + if cluster.target in (JobStatus.STOPPED, JobStatus.FAILED): + if cluster.status == JobStatus.CLOSING: + if self.is_cluster_ready_to_close(cluster): + await self._cluster_to_stopped(cluster) + else: + await self._cluster_to_closing(cluster) + return + + if cluster.target == JobStatus.RUNNING: + if cluster.status == JobStatus.CREATED: + await self._cluster_to_submitted(cluster) + return + + if cluster.status == JobStatus.SUBMITTED and cluster.scheduler_address: + await self._cluster_to_running(cluster) + + if cluster.status == JobStatus.RUNNING: + await self._check_cluster_proxied(cluster) + await self._check_cluster_scale(cluster) + + async def reconcile_worker(self, worker): + if worker.status >= JobStatus.STOPPED: + return + + if worker.target == JobStatus.CLOSING: + if worker.status != JobStatus.CLOSING: + self.db.update_worker(worker, status=JobStatus.CLOSING) + if self.is_cluster_ready_to_close(worker.cluster): + self.queue.put(worker.cluster) + return + + if worker.target in (JobStatus.STOPPED, JobStatus.FAILED): + await self._worker_to_stopped(worker) + if self.is_cluster_ready_to_close(worker.cluster): + self.queue.put(worker.cluster) + elif ( + worker.cluster.target == JobStatus.RUNNING and not worker.close_expected + ): + self.queue.put(worker.cluster) + return + + if worker.status == JobStatus.CREATED and worker.target == JobStatus.RUNNING: + await self._worker_to_submitted(worker) + return + + def is_cluster_ready_to_close(self, cluster): + return ( + cluster.status == JobStatus.CLOSING + and ( + self.supports_bulk_shutdown + and cluster.all_workers_at_least(JobStatus.CLOSING) + ) + or cluster.all_workers_at_least(JobStatus.STOPPED) + ) + + async def _cluster_to_submitted(self, cluster): + self.log.info("Submitting cluster %s...", cluster.name) + try: + async with timeout(self.cluster_start_timeout): + async for state in self.do_start_cluster(cluster): + self.log.debug("State update for cluster %s", cluster.name) + self.db.update_cluster(cluster, state=state) + self.db.update_cluster(cluster, status=JobStatus.SUBMITTED) + self.log.info("Cluster %s submitted", cluster.name) + except asyncio.CancelledError: + raise + except Exception as exc: + if isinstance(exc, asyncio.TimeoutError): + self.log.info("Cluster %s startup timed out", cluster.name) + else: + self.log.warning( + "Failed to submit cluster %s", cluster.name, exc_info=exc + ) + self.db.update_cluster( + cluster, status=JobStatus.SUBMITTED, target=JobStatus.FAILED + ) + self.queue.put(cluster) + + async def _cluster_to_closing(self, cluster): + self.log.debug("Preparing to stop cluster %s", cluster.name) + target = JobStatus.CLOSING if self.supports_bulk_shutdown else JobStatus.STOPPED + workers = [w for w in cluster.workers.values() if w.target < target] + self.db.update_workers([(w, {"target": target}) for w in workers]) + for w in workers: + self.queue.put(w) + self.db.update_cluster(cluster, status=JobStatus.CLOSING) + if not workers: + # If there are workers, the cluster will be enqueued after the last one closed + # If there are no workers, requeue now + 
self.queue.put(cluster) + cluster.ready.set() + + async def _cluster_to_stopped(self, cluster): + self.log.info("Stopping cluster %s...", cluster.name) + if cluster.status > JobStatus.CREATED: + try: + await self.do_stop_cluster(cluster) + except Exception as exc: + self.log.warning( + "Exception while stopping cluster %s", cluster.name, exc_info=exc + ) + await self.proxy.remove_route(kind="PATH", path=f"/clusters/{cluster.name}") + await self.proxy.remove_route(kind="SNI", sni=cluster.name) + self.log.info("Cluster %s stopped", cluster.name) + self.db.update_workers( + [ + (w, {"status": JobStatus.STOPPED, "target": JobStatus.STOPPED}) + for w in cluster.workers.values() + if w.status < JobStatus.STOPPED + ] + ) + self.db.update_cluster(cluster, status=cluster.target, stop_time=timestamp()) + cluster.ready.set() + cluster.shutdown.set() + + async def _cluster_to_running(self, cluster): + self.log.info("Cluster %s is running", cluster.name) + self.db.update_cluster(cluster, status=JobStatus.RUNNING) + cluster.ready.set() + + async def _check_cluster_proxied(self, cluster): + if not cluster.added_to_proxies: + self.log.info("Adding cluster %s routes to proxies", cluster.name) + if cluster.dashboard_address: + await self.proxy.add_route( + kind="PATH", + path=f"/clusters/{cluster.name}", + target=cluster.dashboard_address, + ) + await self.proxy.add_route( + kind="SNI", sni=cluster.name, target=cluster.scheduler_address + ) + cluster.added_to_proxies = True + + async def _check_cluster_scale(self, cluster): + if cluster.worker_start_failure_count >= self.worker_start_failure_limit: + self.log.info( + "Cluster %s had %d consecutive workers fail to start, failing the cluster", + cluster.name, + cluster.worker_start_failure_count, + ) + self.db.update_cluster(cluster, target=JobStatus.FAILED) + self.queue.put(cluster) + return + + active = cluster.active_workers() + if cluster.count > len(active): + for _ in range(cluster.count - len(active)): + worker = self.db.create_worker(cluster) + self.log.info( + "Created worker %s for cluster %s", worker.name, cluster.name + ) + self.queue.put(worker) + + async def _worker_to_submitted(self, worker): + self.log.info("Submitting worker %s...", worker.name) + try: + async with timeout(self.worker_start_timeout): + async for state in self.do_start_worker(worker): + self.log.debug("State update for worker %s", worker.name) + self.db.update_worker(worker, state=state) + self.db.update_worker(worker, status=JobStatus.SUBMITTED) + self.log.info("Worker %s submitted", worker.name) + except asyncio.CancelledError: + raise + except Exception as exc: + if isinstance(exc, asyncio.TimeoutError): + self.log.info("Worker %s startup timed out", worker.name) + else: + self.log.warning( + "Failed to submit worker %s", worker.name, exc_info=exc + ) + self.db.update_worker( + worker, status=JobStatus.SUBMITTED, target=JobStatus.FAILED + ) + worker.cluster.worker_start_failure_count += 1 + self.queue.put(worker) + + async def _worker_to_stopped(self, worker): + self.log.info("Stopping worker %s...", worker.name) + if worker.status > JobStatus.CREATED: + try: + await self.do_stop_worker(worker) + except Exception as exc: + self.log.warning( + "Exception while stopping worker %s", worker.name, exc_info=exc + ) + self.log.info("Worker %s stopped", worker.name) + self.db.update_worker(worker, status=worker.target, stop_time=timestamp()) + + def get_tls_paths(self, cluster): + """Return the paths to the cert and key files for this cluster""" + return "dask.crt", "dask.pem" + + 
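# The relative "dask.crt"/"dask.pem" returned above are only placeholders;
# concrete backends override get_tls_paths() to point at the credential files
# they actually materialize (the jobqueue backend later in this diff joins them
# onto the per-cluster staging directory). A minimal sketch, with MyBackend and
# certs_root as hypothetical names rather than anything defined in this diff:

import os

from dask_gateway_server.backends.db_base import DBBackendBase  # or the vendored copy above


class MyBackend(DBBackendBase):
    # Root directory assumed to exist and be writable by the gateway process.
    certs_root = "/etc/dask-gateway/certs"

    def get_tls_paths(self, cluster):
        # One subdirectory per cluster, holding the cert/key written at start time.
        base = os.path.join(self.certs_root, cluster.name)
        return os.path.join(base, "dask.crt"), os.path.join(base, "dask.pem")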
def get_env(self, cluster): + """Get a dict of environment variables to set for the process""" + out = dict(cluster.config.environment) + # Set values that dask-gateway needs to run + out.update( + { + "DASK_GATEWAY_API_URL": self.api_url, + "DASK_GATEWAY_API_TOKEN": cluster.token, + "DASK_GATEWAY_CLUSTER_NAME": cluster.name, + "DASK_DISTRIBUTED__COMM__REQUIRE_ENCRYPTION": "True", + } + ) + return out + + def get_scheduler_env(self, cluster): + env = self.get_env(cluster) + tls_cert_path, tls_key_path = self.get_tls_paths(cluster) + env.update( + { + "DASK_DISTRIBUTED__COMM__TLS__CA_FILE": tls_cert_path, + "DASK_DISTRIBUTED__COMM__TLS__SCHEDULER__KEY": tls_key_path, + "DASK_DISTRIBUTED__COMM__TLS__SCHEDULER__CERT": tls_cert_path, + } + ) + return env + + def get_worker_env(self, cluster): + env = self.get_env(cluster) + tls_cert_path, tls_key_path = self.get_tls_paths(cluster) + env.update( + { + "DASK_DISTRIBUTED__COMM__TLS__CA_FILE": tls_cert_path, + "DASK_DISTRIBUTED__COMM__TLS__WORKER__KEY": tls_key_path, + "DASK_DISTRIBUTED__COMM__TLS__WORKER__CERT": tls_cert_path, + } + ) + return env + + default_host = "0.0.0.0" + + def get_scheduler_command(self, cluster): + return cluster.config.scheduler_cmd + [ + "--protocol", + "tls", + "--port", + "0", + "--host", + self.default_host, + "--dashboard-address", + f"{self.default_host}:0", + "--preload", + "dask_gateway.scheduler_preload", + "--dg-api-address", + f"{self.default_host}:0", + "--dg-heartbeat-period", + str(self.cluster_heartbeat_period), + "--dg-adaptive-period", + str(cluster.config.adaptive_period), + "--dg-idle-timeout", + str(cluster.config.idle_timeout), + ] + + def worker_nthreads_memory_limit_args(self, cluster): + return str(cluster.config.worker_threads), str(cluster.config.worker_memory) + + def get_worker_command(self, cluster, worker_name, scheduler_address=None): + nthreads, memory_limit = self.worker_nthreads_memory_limit_args(cluster) + if scheduler_address is None: + scheduler_address = cluster.scheduler_address + return cluster.config.worker_cmd + [ + scheduler_address, + "--dashboard-address", + f"{self.default_host}:0", + "--name", + worker_name, + "--nthreads", + nthreads, + "--memory-limit", + memory_limit, + ] + + # Subclasses should implement these methods + supports_bulk_shutdown = False + + async def do_setup(self): + """Called when the server is starting up. + + Do any initialization here. + """ + pass + + async def do_cleanup(self): + """Called when the server is shutting down. + + Do any cleanup here.""" + pass + + async def do_start_cluster(self, cluster): + """Start a cluster. + + This should do any initialization for the whole dask cluster + application, and then start the scheduler. + + Parameters + ---------- + cluster : Cluster + Information on the cluster to be started. + + Yields + ------ + cluster_state : dict + Any state needed for further interactions with this cluster. This + should be serializable using ``json.dumps``. If startup occurs in + multiple stages, can iteratively yield state updates to be + checkpointed. If an error occurs at any time, the last yielded + state will be used when calling ``do_stop_cluster``. + """ + raise NotImplementedError + + async def do_stop_cluster(self, cluster): + """Stop a cluster. + + Parameters + ---------- + cluster : Cluster + Information on the cluster to be stopped. + """ + raise NotImplementedError + + async def do_check_clusters(self, clusters): + """Check the status of multiple clusters. 
+ + This is periodically called to check the status of pending clusters. + Once a cluster is running this will no longer be called. + + Parameters + ---------- + clusters : List[Cluster] + The clusters to be checked. + + Returns + ------- + statuses : List[bool] + The status for each cluster. Return False if the cluster has + stopped or failed, True if the cluster is pending start or running. + """ + raise NotImplementedError + + async def do_start_worker(self, worker): + """Start a worker. + + Parameters + ---------- + worker : Worker + Information on the worker to be started. + + Yields + ------ + worker_state : dict + Any state needed for further interactions with this worker. This + should be serializable using ``json.dumps``. If startup occurs in + multiple stages, can iteratively yield state updates to be + checkpointed. If an error occurs at any time, the last yielded + state will be used when calling ``do_stop_worker``. + """ + raise NotImplementedError + + async def do_stop_worker(self, worker): + """Stop a worker. + + Parameters + ---------- + worker : Worker + Information on the worker to be stopped. + """ + raise NotImplementedError + + async def do_check_workers(self, workers): + """Check the status of multiple workers. + + This is periodically called to check the status of pending workers. + Once a worker is running this will no longer be called. + + Parameters + ---------- + workers : List[Worker] + The workers to be checked. + + Returns + ------- + statuses : List[bool] + The status for each worker. Return False if the worker has + stopped or failed, True if the worker is pending start or running. + """ + raise NotImplementedError diff --git a/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/inprocess.py b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/inprocess.py new file mode 100644 index 0000000..7b4d97b --- /dev/null +++ b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/inprocess.py @@ -0,0 +1,112 @@ +from dask_gateway.scheduler_preload import GatewaySchedulerService, make_gateway_client +from distributed import Scheduler, Security, Worker +from distributed.core import Status +from distributed.utils import TimeoutError + +from .local import UnsafeLocalBackend + +__all__ = ("InProcessBackend",) + + +class InProcessBackend(UnsafeLocalBackend): + """A backend that runs everything in the same process""" + + def get_security(self, cluster): + cert_path, key_path = self.get_tls_paths(cluster) + return Security( + tls_ca_file=cert_path, + tls_scheduler_cert=cert_path, + tls_scheduler_key=key_path, + tls_worker_cert=cert_path, + tls_worker_key=key_path, + ) + + def get_gateway_client(self, cluster): + return make_gateway_client( + cluster_name=cluster.name, api_token=cluster.token, api_url=self.api_url + ) + + def _check_status(self, objs, mapping): + out = [] + for x in objs: + x = mapping.get(x.name) + ok = x is not None and not x.status != Status.closed + out.append(ok) + return out + + async def do_setup(self): + self.schedulers = {} + self.workers = {} + + async def do_start_cluster(self, cluster): + workdir = self.setup_working_directory(cluster) + yield {"workdir": workdir} + + security = self.get_security(cluster) + gateway_client = self.get_gateway_client(cluster) + + self.schedulers[cluster.name] = scheduler = Scheduler( + protocol="tls", + host="127.0.0.1", + port=0, + dashboard_address="127.0.0.1:0", + security=security, + services={ + ("gateway", ":0"): ( + GatewaySchedulerService, + { + "gateway": gateway_client, + 
"heartbeat_period": self.cluster_heartbeat_period, + "adaptive_period": cluster.config.adaptive_period, + "idle_timeout": cluster.config.idle_timeout, + }, + ) + }, + ) + await scheduler + yield {"workdir": workdir, "started": True} + + async def do_stop_cluster(self, cluster): + scheduler = self.schedulers.pop(cluster.name) + + await scheduler.close() + scheduler.stop() + + workdir = cluster.state.get("workdir") + if workdir is not None: + self.cleanup_working_directory(workdir) + + async def do_check_clusters(self, clusters): + return self._check_status(clusters, self.schedulers) + + async def do_start_worker(self, worker): + security = self.get_security(worker.cluster) + workdir = worker.cluster.state["workdir"] + self.workers[worker.name] = worker = Worker( + worker.cluster.scheduler_address, + nthreads=worker.cluster.config.worker_threads, + memory_limit=0, + security=security, + name=worker.name, + local_directory=workdir, + ) + await worker + yield {"started": True} + + async def do_stop_worker(self, worker): + worker = self.workers.pop(worker.name, None) + if worker is None: + return + try: + await worker.close(timeout=1) + except TimeoutError: + pass + + async def do_check_workers(self, workers): + return self._check_status(workers, self.workers) + + async def worker_status(self, worker_name, worker_state, cluster_state): + worker = self.workers.get(worker_name) + if worker is None: + return False + return not worker.status != Status.closed diff --git a/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/jobqueue/__init__.py b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/jobqueue/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/jobqueue/base.py b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/jobqueue/base.py new file mode 100644 index 0000000..c85540b --- /dev/null +++ b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/jobqueue/base.py @@ -0,0 +1,213 @@ +import asyncio +import json +import os +import pwd +import shutil + +from traitlets import Unicode, default + +from ..base import ClusterConfig +from ..db_base import DBBackendBase + +__all__ = ("JobQueueClusterConfig", "JobQueueBackend") + + +class JobQueueClusterConfig(ClusterConfig): + worker_setup = Unicode( + "", help="Script to run before dask worker starts.", config=True + ) + + scheduler_setup = Unicode( + "", help="Script to run before dask scheduler starts.", config=True + ) + + staging_directory = Unicode( + "{home}/.dask-gateway/", + help=""" + The staging directory for storing files before the job starts. + + A subdirectory will be created for each new cluster which will store + temporary files for that cluster. On cluster shutdown the subdirectory + will be removed. 
+ + This field can be a template, which receives the following fields: + + - home (the user's home directory) + - username (the user's name) + """, + config=True, + ) + + +class JobQueueBackend(DBBackendBase): + """A base cluster manager for deploying Dask on a jobqueue cluster.""" + + dask_gateway_jobqueue_launcher = Unicode( + help="The path to the dask-gateway-jobqueue-launcher executable", config=True + ) + + @default("dask_gateway_jobqueue_launcher") + def _default_launcher_path(self): + return ( + shutil.which("dask-gateway-jobqueue-launcher") + or "dask-gateway-jobqueue-launcher" + ) + + submit_command = Unicode(help="The path to the job submit command", config=True) + + cancel_command = Unicode(help="The path to the job cancel command", config=True) + + status_command = Unicode(help="The path to the job status command", config=True) + + def get_submit_cmd_env_stdin(self, cluster, worker=None): + raise NotImplementedError + + def get_stop_cmd_env(self, job_id): + raise NotImplementedError + + def get_status_cmd_env(self, job_ids): + raise NotImplementedError + + def parse_job_id(self, stdout): + raise NotImplementedError + + def parse_job_states(self, stdout): + raise NotImplementedError + + def get_staging_directory(self, cluster): + staging_dir = cluster.config.staging_directory.format( + home=pwd.getpwnam(cluster.username).pw_dir, username=cluster.username + ) + return os.path.join(staging_dir, cluster.name) + + def get_tls_paths(self, cluster): + """Get the absolute paths to the tls cert and key files.""" + staging_dir = self.get_staging_directory(cluster) + cert_path = os.path.join(staging_dir, "dask.crt") + key_path = os.path.join(staging_dir, "dask.pem") + return cert_path, key_path + + async def do_as_user(self, user, action, **kwargs): + cmd = ["sudo", "-nHu", user, self.dask_gateway_jobqueue_launcher] + kwargs["action"] = action + proc = await asyncio.create_subprocess_exec( + *cmd, + env={}, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate(json.dumps(kwargs).encode("utf8")) + stdout = stdout.decode("utf8", "replace") + stderr = stderr.decode("utf8", "replace") + + if proc.returncode != 0: + raise Exception( + "Error running `dask-gateway-jobqueue-launcher`\n" + " returncode: %d\n" + " stdout: %s\n" + " stderr: %s" % (proc.returncode, stdout, stderr) + ) + result = json.loads(stdout) + if not result["ok"]: + raise Exception(result["error"]) + return result["returncode"], result["stdout"], result["stderr"] + + async def start_job(self, username, cmd, env, stdin, staging_dir=None, files=None): + code, stdout, stderr = await self.do_as_user( + user=username, + action="start", + cmd=cmd, + env=env, + stdin=stdin, + staging_dir=staging_dir, + files=files, + ) + if code != 0: + raise Exception( + ( + "Failed to submit job to batch system\n" + " exit_code: %d\n" + " stdout: %s\n" + " stderr: %s" + ) + % (code, stdout, stderr) + ) + return self.parse_job_id(stdout) + + async def stop_job(self, username, job_id, staging_dir=None): + cmd, env = self.get_stop_cmd_env(job_id) + + code, stdout, stderr = await self.do_as_user( + user=username, action="stop", cmd=cmd, env=env, staging_dir=staging_dir + ) + if code != 0 and "Job has finished" not in stderr: + raise Exception("Failed to stop job_id %s" % job_id) + + async def check_jobs(self, job_ids): + if not job_ids: + return {} + self.log.debug("Checking status of %d jobs", len(job_ids)) + cmd, env = self.get_status_cmd_env(job_ids) + proc 
= await asyncio.create_subprocess_exec( + *cmd, + env=env, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + stdout = stdout.decode("utf8", "replace") + if proc.returncode != 0: + stderr = stderr.decode("utf8", "replace") + self.log.warning( + "Job status check failed with returncode %d, stderr: %s", + proc.returncode, + stderr, + ) + raise Exception("Job status check failed") + return self.parse_job_states(stdout) + + async def do_start_cluster(self, cluster): + cmd, env, stdin = self.get_submit_cmd_env_stdin(cluster) + staging_dir = self.get_staging_directory(cluster) + files = { + "dask.pem": cluster.tls_key.decode("utf8"), + "dask.crt": cluster.tls_cert.decode("utf8"), + } + job_id = await self.start_job( + cluster.username, cmd, env, stdin, staging_dir=staging_dir, files=files + ) + self.log.info("Job %s submitted for cluster %s", job_id, cluster.name) + yield {"job_id": job_id, "staging_dir": staging_dir} + + async def do_stop_cluster(self, cluster): + job_id = cluster.state.get("job_id") + if job_id is not None: + staging_dir = cluster.state["staging_dir"] + await self.stop_job(cluster.username, job_id, staging_dir=staging_dir) + + async def do_start_worker(self, worker): + cmd, env, stdin = self.get_submit_cmd_env_stdin(worker.cluster, worker) + job_id = await self.start_job(worker.cluster.username, cmd, env, stdin) + self.log.info("Job %s submitted for worker %s", job_id, worker.name) + yield {"job_id": job_id} + + async def do_stop_worker(self, worker): + job_id = worker.state.get("job_id") + if job_id is not None: + await self.stop_job(worker.cluster.username, job_id) + + async def _do_check(self, objs): + id_map = {} + for x in objs: + job_id = x.state.get("job_id") + if job_id is not None: + id_map[x.name] = job_id + states = await self.check_jobs(list(id_map.values())) + return [states.get(id_map.get(x.name), False) for x in objs] + + async def do_check_clusters(self, clusters): + return await self._do_check(clusters) + + async def do_check_workers(self, workers): + return await self._do_check(workers) diff --git a/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/jobqueue/launcher.py b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/jobqueue/launcher.py new file mode 100644 index 0000000..6926ab3 --- /dev/null +++ b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/jobqueue/launcher.py @@ -0,0 +1,87 @@ +import json +import os +import shutil +import subprocess +import sys + + +def finish(**kwargs): + json.dump(kwargs, sys.stdout) + sys.stdout.flush() + + +def run_command(cmd, env, stdin=None): + if stdin is not None: + stdin = stdin.encode("utf8") + STDIN = subprocess.PIPE + else: + STDIN = None + + proc = subprocess.Popen( + cmd, + env=env, + cwd=os.path.expanduser("~"), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=STDIN, + ) + + stdout, stderr = proc.communicate(stdin) + + finish( + ok=True, + returncode=proc.returncode, + stdout=stdout.decode("utf8", "replace"), + stderr=stderr.decode("utf8", "replace"), + ) + + +def start(cmd, env, stdin=None, staging_dir=None, files=None): + if staging_dir: + try: + os.makedirs(staging_dir, mode=0o700, exist_ok=False) + for name, value in files.items(): + with open(os.path.join(staging_dir, name), "w") as f: + f.write(value) + except Exception as exc: + finish( + ok=False, + error=f"Error setting up staging directory {staging_dir}: {exc}", + ) + return + run_command(cmd, env, stdin=stdin) + + +def stop(cmd, env, 
staging_dir=None): + if staging_dir: + if not os.path.exists(staging_dir): + return + try: + shutil.rmtree(staging_dir) + except Exception as exc: + finish( + ok=False, + error=f"Error removing staging directory {staging_dir}: {exc}", + ) + return + run_command(cmd, env) + + +def main(): + try: + kwargs = json.load(sys.stdin) + except ValueError as exc: + finish(ok=False, error=str(exc)) + return + + action = kwargs.pop("action", None) + if action == "start": + start(**kwargs) + elif action == "stop": + stop(**kwargs) + else: + finish(ok=False, error="Valid actions are 'start' and 'stop'") + + +if __name__ == "__main__": + main() diff --git a/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/jobqueue/slurm.py b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/jobqueue/slurm.py new file mode 100644 index 0000000..7b9607c --- /dev/null +++ b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/jobqueue/slurm.py @@ -0,0 +1,120 @@ +import math +import os +import shutil + +from traitlets import Unicode, default + +from dask_gateway_server.traitlets import Type +from .base import JobQueueBackend, JobQueueClusterConfig + +__all__ = ("SlurmBackend", "SlurmClusterConfig") + + +def slurm_format_memory(n): + """Format memory in bytes for use with slurm.""" + if n >= 10 * (1024 ** 3): + return "%dG" % math.ceil(n / (1024 ** 3)) + if n >= 10 * (1024 ** 2): + return "%dM" % math.ceil(n / (1024 ** 2)) + if n >= 10 * 1024: + return "%dK" % math.ceil(n / 1024) + return "1K" + + +class SlurmClusterConfig(JobQueueClusterConfig): + """Dask cluster configuration options when running on SLURM""" + + partition = Unicode("", help="The partition to submit jobs to.", config=True) + + qos = Unicode("", help="QOS string associated with each job.", config=True) + + account = Unicode("", help="Account string associated with each job.", config=True) + + +class SlurmBackend(JobQueueBackend): + """A backend for deploying Dask on a Slurm cluster.""" + + cluster_config_class = Type( + "dask_gateway_server.backends.jobqueue.slurm.SlurmClusterConfig", + klass="dask_gateway_server.backends.base.ClusterConfig", + help="The cluster config class to use", + config=True, + ) + + @default("submit_command") + def _default_submit_command(self): + return shutil.which("sbatch") or "sbatch" + + @default("cancel_command") + def _default_cancel_command(self): + return shutil.which("scancel") or "scancel" + + @default("status_command") + def _default_status_command(self): + return shutil.which("squeue") or "squeue" + + def get_submit_cmd_env_stdin(self, cluster, worker=None): + cmd = [self.submit_command, "--parsable"] + cmd.append("--job-name=dask-gateway") + if cluster.config.partition: + cmd.append("--partition=" + cluster.config.partition) + if cluster.config.account: + cmd.append("--account=" + cluster.config.account) + if cluster.config.qos: + cmd.extend("--qos=" + cluster.config.qos) + + if worker: + cpus = cluster.config.worker_cores + mem = slurm_format_memory(cluster.config.worker_memory) + log_file = "dask-worker-%s.log" % worker.name + script = "\n".join( + [ + "#!/bin/sh", + cluster.config.worker_setup, + " ".join(self.get_worker_command(cluster, worker.name)), + ] + ) + env = self.get_worker_env(cluster) + else: + cpus = cluster.config.scheduler_cores + mem = slurm_format_memory(cluster.config.scheduler_memory) + log_file = "dask-scheduler-%s.log" % cluster.name + script = "\n".join( + [ + "#!/bin/sh", + cluster.config.scheduler_setup, + " ".join(self.get_scheduler_command(cluster)), + ] + ) + env = 
self.get_scheduler_env(cluster) + + staging_dir = self.get_staging_directory(cluster) + + cmd.extend( + [ + "--chdir=" + staging_dir, + "--output=" + os.path.join(staging_dir, log_file), + "--cpus-per-task=%d" % cpus, + "--mem=%s" % mem, + "--export=%s" % (",".join(sorted(env))), + ] + ) + + return cmd, env, script + + def get_stop_cmd_env(self, job_id): + return [self.cancel_command, job_id], {} + + def get_status_cmd_env(self, job_ids): + cmd = [self.status_command, "-h", "--job=%s" % ",".join(job_ids), "-o", "%i %t"] + return cmd, {} + + def parse_job_states(self, stdout): + states = {} + for l in stdout.splitlines(): + job_id, state = l.split() + states[job_id] = state in ("R", "CG", "PD", "CF") + return states + + def parse_job_id(self, stdout): + return stdout.strip() diff --git a/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/local.py b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/local.py new file mode 100644 index 0000000..6af5620 --- /dev/null +++ b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/local.py @@ -0,0 +1,318 @@ +import asyncio +import errno +import functools +import grp +import os +import pwd +import shutil +import signal +import sys +import tempfile + +from traitlets import Integer, List, Unicode + +from dask_gateway_server.traitlets import Type +from .base import ClusterConfig +from .db_base import DBBackendBase + +__all__ = ("LocalClusterConfig", "LocalBackend", "UnsafeLocalBackend") + + +class LocalClusterConfig(ClusterConfig): + """Dask cluster configuration options when running as local processes""" + + pass + + +def _signal(pid, sig): + """Send given signal to a pid. + + Returns True if the process still exists, False otherwise.""" + try: + os.kill(pid, sig) + except OSError as e: + if e.errno == errno.ESRCH: + return False + raise + return True + + +def is_running(pid): + return _signal(pid, 0) + + +async def wait_is_shutdown(pid, timeout=10): + """Wait for a pid to shutdown, using exponential backoff""" + pause = 0.1 + while timeout >= 0: + if not _signal(pid, 0): + return True + await asyncio.sleep(pause) + timeout -= pause + pause *= 2 + return False + + +@functools.lru_cache +def getpwnam(username): + return pwd.getpwnam(username) + + +class LocalBackend(DBBackendBase): + """A cluster backend that launches local processes. + + Requires super-user permissions in order to run processes for the + requesting username. + """ + + cluster_config_class = Type( + "dask_gateway_server.backends.local.LocalClusterConfig", + klass="dask_gateway_server.backends.base.ClusterConfig", + help="The cluster config class to use", + config=True, + ) + + sigint_timeout = Integer( + 10, + help=""" + Seconds to wait for process to stop after SIGINT. + + If the process has not stopped after this time, a SIGTERM is sent. + """, + config=True, + ) + + sigterm_timeout = Integer( + 5, + help=""" + Seconds to wait for process to stop after SIGTERM. + + If the process has not stopped after this time, a SIGKILL is sent. + """, + config=True, + ) + + sigkill_timeout = Integer( + 5, + help=""" + Seconds to wait for process to stop after SIGKILL. + + If the process has not stopped after this time, a warning is logged and + the process is deemed a zombie process. + """, + config=True, + ) + + clusters_directory = Unicode( + help=""" + The base directory for cluster working directories. + + A subdirectory will be created for each new cluster which will serve as + the working directory for that cluster. On cluster shutdown the + subdirectory will be removed. 
+ + If not specified, a temporary directory will be used for each cluster. + """, + config=True, + ) + + inherited_environment = List( + [ + "PATH", + "PYTHONPATH", + "CONDA_ROOT", + "CONDA_DEFAULT_ENV", + "VIRTUAL_ENV", + "LANG", + "LC_ALL", + ], + help=""" + Whitelist of environment variables for the scheduler and worker + processes to inherit from the Dask-Gateway process. + """, + config=True, + ) + + default_host = "127.0.0.1" + + def set_file_permissions(self, paths, username): + pwnam = getpwnam(username) + for p in paths: + os.chown(p, pwnam.pw_uid, pwnam.pw_gid) + + def make_preexec_fn(self, cluster): # pragma: nocover + # Borrowed and modified from jupyterhub/spawner.py + pwnam = getpwnam(cluster.username) + uid = pwnam.pw_uid + gid = pwnam.pw_gid + groups = [g.gr_gid for g in grp.getgrall() if cluster.username in g.gr_mem] + workdir = cluster.state["workdir"] + + def preexec(): + os.setgid(gid) + try: + os.setgroups(groups) + except Exception as e: + print("Failed to set groups %s" % e, file=sys.stderr) + os.setuid(uid) + os.chdir(workdir) + + return preexec + + def setup_working_directory(self, cluster): # pragma: nocover + if self.clusters_directory: + workdir = os.path.join(self.clusters_directory, cluster.name) + else: + workdir = tempfile.mkdtemp(prefix="dask", suffix=cluster.name) + certsdir = self.get_certs_directory(workdir) + logsdir = self.get_logs_directory(workdir) + + paths = [workdir, certsdir, logsdir] + for path in paths: + os.makedirs(path, 0o700, exist_ok=True) + + cert_path, key_path = self._get_tls_paths(workdir) + flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL + for path, data in [(cert_path, cluster.tls_cert), (key_path, cluster.tls_key)]: + with os.fdopen(os.open(path, flags, 0o600), "wb") as fil: + fil.write(data) + paths.extend(path) + + self.set_file_permissions(paths, cluster.username) + + self.log.debug( + "Working directory %s for cluster %s created", workdir, cluster.name + ) + return workdir + + def cleanup_working_directory(self, workdir): + if os.path.exists(workdir): + try: + shutil.rmtree(workdir) + self.log.debug("Working directory %s removed", workdir) + except Exception: # pragma: nocover + self.log.warn("Failed to remove working directory %r", workdir) + + def get_certs_directory(self, workdir): + return os.path.join(workdir, ".certs") + + def get_logs_directory(self, workdir): + return os.path.join(workdir, "logs") + + def _get_tls_paths(self, workdir): + certsdir = self.get_certs_directory(workdir) + cert_path = os.path.join(certsdir, "dask.crt") + key_path = os.path.join(certsdir, "dask.pem") + return cert_path, key_path + + def get_tls_paths(self, cluster): + return self._get_tls_paths(cluster.state["workdir"]) + + def get_env(self, cluster): + env = super().get_env(cluster) + for key in self.inherited_environment: + if key in os.environ: + env[key] = os.environ[key] + env["USER"] = cluster.username + return env + + async def start_process(self, cluster, cmd, env, name): + workdir = cluster.state["workdir"] + logsdir = self.get_logs_directory(workdir) + log_path = os.path.join(logsdir, name + ".log") + flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL + fd = None + try: + fd = os.open(log_path, flags, 0o755) + proc = await asyncio.create_subprocess_exec( + *cmd, + preexec_fn=self.make_preexec_fn(cluster), + start_new_session=True, + env=env, + stdout=fd, + stderr=asyncio.subprocess.STDOUT, + ) + finally: + if fd is not None: + os.close(fd) + return proc.pid + + async def stop_process(self, pid): + methods = [ + ("SIGINT", signal.SIGINT, 
self.sigint_timeout), + ("SIGTERM", signal.SIGTERM, self.sigterm_timeout), + ("SIGKILL", signal.SIGKILL, self.sigkill_timeout), + ] + + for msg, sig, timeout in methods: + self.log.debug("Sending %s to process %d", msg, pid) + _signal(pid, sig) + if await wait_is_shutdown(pid, timeout): + return + + if is_running(pid): + # all attempts failed, zombie process + self.log.warn("Failed to stop process %d", pid) + + async def do_start_cluster(self, cluster): + workdir = self.setup_working_directory(cluster) + yield {"workdir": workdir} + + pid = await self.start_process( + cluster, + self.get_scheduler_command(cluster), + self.get_scheduler_env(cluster), + "scheduler", + ) + yield {"workdir": workdir, "pid": pid} + + async def do_stop_cluster(self, cluster): + pid = cluster.state.get("pid") + if pid is not None: + await self.stop_process(pid) + + workdir = cluster.state.get("workdir") + if workdir is not None: + self.cleanup_working_directory(workdir) + + def _check_status(self, o): + pid = o.state.get("pid") + return pid is not None and is_running(pid) + + async def do_check_clusters(self, clusters): + return [self._check_status(c) for c in clusters] + + async def do_start_worker(self, worker): + cmd = self.get_worker_command(worker.cluster, worker.name) + env = self.get_worker_env(worker.cluster) + pid = await self.start_process( + worker.cluster, cmd, env, "worker-%s" % worker.name + ) + yield {"pid": pid} + + async def do_stop_worker(self, worker): + pid = worker.state.get("pid") + if pid is not None: + await self.stop_process(pid) + + async def do_check_workers(self, workers): + return [self._check_status(w) for w in workers] + + +class UnsafeLocalBackend(LocalBackend): + """A version of LocalBackend that doesn't set permissions. + + FOR TESTING ONLY! This provides no user separations - clusters run with the + same level of permission as the gateway. + """ + + def make_preexec_fn(self, cluster): + workdir = cluster.state["workdir"] + + def preexec(): # pragma: nocover + os.chdir(workdir) + + return preexec + + def set_file_permissions(self, paths, username): + pass diff --git a/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/yarn.py b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/yarn.py new file mode 100644 index 0000000..ea16638 --- /dev/null +++ b/aws_pcluster_dask_gateway/dask_gateway_extensions/backends/yarn.py @@ -0,0 +1,285 @@ +from asyncio import get_running_loop +from collections import defaultdict +from tempfile import NamedTemporaryFile + +try: + import skein +except ImportError: + raise ImportError( + "'%s.YarnBackend' requires 'skein' as a dependency. " + "To install required dependencies, use:\n" + " $ pip install dask-gateway-server[yarn]\n" + "or\n" + " $ conda install dask-gateway-server-yarn -c conda-forge\n" % __name__ + ) + +from traitlets import Dict, Integer, Unicode + +from dask_gateway_server.traitlets import Type +from dask_gateway_server.utils import LRUCache +from .base import ClusterConfig +from .db_base import DBBackendBase + +__all__ = ("YarnClusterConfig", "YarnBackend") + + +class YarnClusterConfig(ClusterConfig): + """Dask cluster configuration options when running on Hadoop/YARN""" + + queue = Unicode( + "default", help="The YARN queue to submit applications under", config=True + ) + + localize_files = Dict( + help=""" + Extra files to distribute to both the worker and scheduler containers. + + This is a mapping from ``local-name`` to ``resource``. Resource paths + can be local, or in HDFS (prefix with ``hdfs://...`` if so). 
If an + archive (``.tar.gz`` or ``.zip``), the resource will be unarchived as + directory ``local-name``. For finer control, resources can also be + specified as ``skein.File`` objects, or their ``dict`` equivalents. + + This can be used to distribute conda/virtual environments by + configuring the following: + + .. code:: + + c.YarnClusterConfig.localize_files = { + 'environment': { + 'source': 'hdfs:///path/to/archived/environment.tar.gz', + 'visibility': 'public' + } + } + c.YarnClusterConfig.scheduler_setup = 'source environment/bin/activate' + c.YarnClusterConfig.worker_setup = 'source environment/bin/activate' + + These archives are usually created using either ``conda-pack`` or + ``venv-pack``. For more information on distributing files, see + https://jcristharif.com/skein/distributing-files.html. + """, + config=True, + ) + + worker_setup = Unicode( + "", help="Script to run before dask worker starts.", config=True + ) + + scheduler_setup = Unicode( + "", help="Script to run before dask scheduler starts.", config=True + ) + + +class YarnBackend(DBBackendBase): + """A cluster backend for managing dask clusters on Hadoop/YARN.""" + + cluster_config_class = Type( + "dask_gateway_server.backends.yarn.YarnClusterConfig", + klass="dask_gateway_server.backends.base.ClusterConfig", + help="The cluster config class to use", + config=True, + ) + + principal = Unicode( + None, + help="Kerberos principal for Dask Gateway user", + allow_none=True, + config=True, + ) + + keytab = Unicode( + None, + help="Path to kerberos keytab for Dask Gateway user", + allow_none=True, + config=True, + ) + + app_client_cache_max_size = Integer( + 10, + help=""" + The max size of the cache for application clients. + + A larger cache will result in improved performance, but will also use + more resources. 
+ """, + config=True, + ) + + def async_apply(self, f, *args, **kwargs): + return get_running_loop().run_in_executor(None, lambda: f(*args, **kwargs)) + + def _get_security(self, cluster): + return skein.Security(cert_bytes=cluster.tls_cert, key_bytes=cluster.tls_key) + + async def _get_app_client(self, cluster): + out = self.app_client_cache.get(cluster.name) + if out is None: + app_id = cluster.state["app_id"] + security = self._get_security(cluster) + if cluster.name not in self.app_address_cache: + # Lookup and cache the application address + report = self.skein_client.application_report(app_id) + if report.state != "RUNNING": # pragma: nocover + raise ValueError("Application %s is not running" % app_id) + app_address = "%s:%d" % (report.host, report.port) + self.app_address_cache[cluster.name] = app_address + app_address = self.app_address_cache[cluster.name] + out = skein.ApplicationClient(app_address, app_id, security=security) + self.app_client_cache.put(cluster.name, out) + return out + + def worker_nthreads_memory_limit_args(self, cluster): + return "$SKEIN_RESOURCE_VCORES", "${SKEIN_RESOURCE_MEMORY}MiB" + + def _build_specification(self, cluster, cert_path, key_path): + files = { + k: skein.File.from_dict(v) if isinstance(v, dict) else v + for k, v in cluster.config.localize_files.items() + } + + files["dask.crt"] = cert_path + files["dask.pem"] = key_path + + scheduler_cmd = " ".join(self.get_scheduler_command(cluster)) + worker_cmd = " ".join( + self.get_worker_command( + cluster, + worker_name="$DASK_GATEWAY_WORKER_NAME", + scheduler_address="$DASK_GATEWAY_SCHEDULER_ADDRESS", + ) + ) + scheduler_script = f"{cluster.config.scheduler_setup}\n{scheduler_cmd}" + worker_script = f"{cluster.config.worker_setup}\n{worker_cmd}" + + master = skein.Master( + security=self._get_security(cluster), + resources=skein.Resources( + memory="%d b" % cluster.config.scheduler_memory, + vcores=cluster.config.scheduler_cores, + ), + files=files, + env=self.get_scheduler_env(cluster), + script=scheduler_script, + ) + + services = { + "dask.worker": skein.Service( + resources=skein.Resources( + memory="%d b" % cluster.config.worker_memory, + vcores=cluster.config.worker_cores, + ), + instances=0, + max_restarts=0, + allow_failures=True, + files=files, + env=self.get_worker_env(cluster), + script=worker_script, + ) + } + + return skein.ApplicationSpec( + name="dask-gateway", + queue=cluster.config.queue, + user=cluster.username, + master=master, + services=services, + ) + + supports_bulk_shutdown = True + + async def do_setup(self): + self.skein_client = await self.async_apply( + skein.Client, + principal=self.principal, + keytab=self.keytab, + security=skein.Security.new_credentials(), + ) + + self.app_client_cache = LRUCache(self.app_client_cache_max_size) + self.app_address_cache = {} + + async def do_cleanup(self): + self.skein_client.close() + + async def do_start_cluster(self, cluster): + with NamedTemporaryFile() as cert_fil, NamedTemporaryFile() as key_fil: + cert_fil.write(cluster.tls_cert) + cert_fil.file.flush() + key_fil.write(cluster.tls_key) + key_fil.file.flush() + spec = self._build_specification(cluster, cert_fil.name, key_fil.name) + app_id = await self.async_apply(self.skein_client.submit, spec) + + yield {"app_id": app_id} + + async def do_stop_cluster(self, cluster): + app_id = cluster.state.get("app_id") + if app_id is None: + return + + await self.async_apply(self.skein_client.kill_application, app_id) + # Remove cluster from caches + self.app_client_cache.discard(cluster.name) 
+ self.app_address_cache.pop(cluster.name, None) + + async def do_check_clusters(self, clusters): + results = [] + for cluster in clusters: + app_id = cluster.state.get("app_id") + if app_id is None: + return False + report = await self.async_apply( + self.skein_client.application_report, app_id + ) + ok = str(report.state) not in {"FAILED", "KILLED", "FINISHED"} + results.append(ok) + return results + + async def do_start_worker(self, worker): + app = await self._get_app_client(worker.cluster) + container = await self.async_apply( + app.add_container, + "dask.worker", + env={ + "DASK_GATEWAY_WORKER_NAME": worker.name, + "DASK_GATEWAY_SCHEDULER_ADDRESS": worker.cluster.scheduler_address, + }, + ) + yield {"container_id": container.id} + + async def do_stop_worker(self, worker): + container_id = worker.state.get("container_id") + if container_id is None: + return + + app = await self._get_app_client(worker.cluster) + try: + await self.async_apply(app.kill_container, container_id) + except ValueError: + pass + + async def do_check_workers(self, workers): + grouped = defaultdict(list) + for w in workers: + grouped[w.cluster].append(w) + + results = {} + for cluster, workers in grouped.items(): + app = await self._get_app_client(cluster) + try: + containers = await self.async_apply( + app.get_containers, services=("dask.worker",) + ) + active = {c.id for c in containers} + results.update( + {w.name: w.state.get("container_id") in active for w in workers} + ) + except Exception as exc: + self.log.debug( + "Error getting worker statuses for cluster %s", + cluster.name, + exc_info=exc, + ) + results.update({w.name: False for w in workers}) + + return [results[w.name] for w in workers] diff --git a/docs/conf.py b/docs/conf.py index 6f27c14..1eef377 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,7 +19,8 @@ # import os import sys -sys.path.insert(0, os.path.abspath('..')) + +sys.path.insert(0, os.path.abspath("..")) import aws_pcluster_dask_gateway @@ -31,22 +32,22 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode'] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.viewcode"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'aws_pcluster_dask_gateway' +project = "aws_pcluster_dask_gateway" copyright = "2022, Jillian Rowe" author = "Jillian Rowe" @@ -69,10 +70,10 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -83,7 +84,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. 
# -html_theme = 'alabaster' +html_theme = "alabaster" # Theme options are theme-specific and customize the look and feel of a # theme further. For a list of options available for each theme, see the @@ -94,13 +95,13 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # -- Options for HTMLHelp output --------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'aws_pcluster_dask_gatewaydoc' +htmlhelp_basename = "aws_pcluster_dask_gatewaydoc" # -- Options for LaTeX output ------------------------------------------ @@ -109,15 +110,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -127,9 +125,13 @@ # (source start file, target name, title, author, documentclass # [howto, manual, or own class]). latex_documents = [ - (master_doc, 'aws_pcluster_dask_gateway.tex', - 'aws_pcluster_dask_gateway Documentation', - 'Jillian Rowe', 'manual'), + ( + master_doc, + "aws_pcluster_dask_gateway.tex", + "aws_pcluster_dask_gateway Documentation", + "Jillian Rowe", + "manual", + ), ] @@ -138,9 +140,13 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'aws_pcluster_dask_gateway', - 'aws_pcluster_dask_gateway Documentation', - [author], 1) + ( + master_doc, + "aws_pcluster_dask_gateway", + "aws_pcluster_dask_gateway Documentation", + [author], + 1, + ) ] @@ -150,13 +156,13 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'aws_pcluster_dask_gateway', - 'aws_pcluster_dask_gateway Documentation', - author, - 'aws_pcluster_dask_gateway', - 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "aws_pcluster_dask_gateway", + "aws_pcluster_dask_gateway Documentation", + author, + "aws_pcluster_dask_gateway", + "One line description of project.", + "Miscellaneous", + ), ] - - - diff --git a/requirements.txt b/requirements.txt index 1986770..69caf19 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -dask-gateway-server[local] -dask-gateway-server[jobqueue] +dask-gateway-server datasize humanize +pydantic>2 aws-pcluster-helpers>=3.5 typer # conda diff --git a/setup.py b/setup.py index d66e8ae..e18a380 100644 --- a/setup.py +++ b/setup.py @@ -5,10 +5,10 @@ from setuptools import setup, find_packages import versioneer -with open('README.rst') as readme_file: +with open("README.rst") as readme_file: readme = readme_file.read() -with open('HISTORY.rst') as history_file: +with open("HISTORY.rst") as history_file: history = history_file.read() with open("requirements.txt", "r") as fh: @@ -18,34 +18,36 @@ setup( author="Jillian Rowe", - author_email='jillian@dabbleofdevops.com', - python_requires='>=3.6', + author_email="jillian@dabbleofdevops.com", + python_requires=">=3.6", classifiers=[ - 'Development Status :: 2 - Pre-Alpha', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: Apache Software License', - 'Natural Language :: English', - 'Programming Language :: Python :: 3', - 
'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", ], description="A helper function to configure Dask Gateway to play nice with AWS PCluster", entry_points={ - 'console_scripts': [ - 'aws_pcluster_dask_gateway=aws_pcluster_dask_gateway.cli:main', + "console_scripts": [ + "aws_pcluster_dask_gateway=aws_pcluster_dask_gateway.cli:main", ], }, install_requires=requirements, license="Apache Software License 2.0", - long_description=readme + '\n\n' + history, + long_description=readme + "\n\n" + history, include_package_data=True, - keywords='aws_pcluster_dask_gateway', - name='aws_pcluster_dask_gateway', - packages=find_packages(include=['aws_pcluster_dask_gateway', 'aws_pcluster_dask_gateway.*']), - test_suite='tests', + keywords="aws_pcluster_dask_gateway", + name="aws_pcluster_dask_gateway", + packages=find_packages( + include=["aws_pcluster_dask_gateway", "aws_pcluster_dask_gateway.*"] + ), + test_suite="tests", tests_require=test_requirements, - url='https://github.com/dabble-of-devops-bioanalyze/aws_pcluster_dask_gateway', + url="https://github.com/dabble-of-devops-bioanalyze/aws_pcluster_dask_gateway", zip_safe=False, # version="0.1.0", # version="3.5.1", diff --git a/tests/test_aws_pcluster_dask_gateway.py b/tests/test_aws_pcluster_dask_gateway.py index 4c9808f..a5cfffc 100644 --- a/tests/test_aws_pcluster_dask_gateway.py +++ b/tests/test_aws_pcluster_dask_gateway.py @@ -24,7 +24,6 @@ ENV_INSTANCE_TYPES_DATA_FILE, ENV_INSTANCE_TYPE_MAPPINGS_FILE, ) -from aws_pcluster_helpers.models.sinfo import SInfoTable, SinfoRow import yaml import json import os @@ -33,6 +32,7 @@ from aws_pcluster_helpers.commands import cli_sinfo from aws_pcluster_helpers.commands import cli_gen_nxf_slurm_config +from aws_pcluster_helpers.models.sinfo import SInfoTable, SinfoRow instance_types_data_file = os.path.join( os.path.dirname(__file__), "instance-types-data.json" @@ -40,15 +40,16 @@ instance_type_mapping_file = os.path.join( os.path.dirname(__file__), "instance_name_type_mappings.json" ) -pcluster_config_file = os.path.join( - os.path.dirname(__file__), "pcluster_config.yml" -) +pcluster_config_file = os.path.join(os.path.dirname(__file__), "pcluster_config.yml") + os.environ[ENV_INSTANCE_TYPE_MAPPINGS_FILE] = instance_type_mapping_file os.environ[ENV_INSTANCE_TYPES_DATA_FILE] = instance_types_data_file os.environ[ENV_PCLUSTER_CONFIG_FILE] = pcluster_config_file logger = setup_logger(logger_name="tests", log_level="DEBUG") +from aws_pcluster_dask_gateway import DaskGatewaySlurmConfig + def test_files(): assert os.path.exists(instance_type_mapping_file) @@ -56,13 +57,11 @@ def test_files(): assert os.path.exists(instance_types_data_file) -def test_sinfo(): - sinfo = SInfoTable() - table = sinfo.get_table() - console = Console() - console.print(table) - - -def test_load_pcluster_config(): - pcluster_config = PClusterConfig.from_yaml(pcluster_config_file) - assert pcluster_config +def test_dask_gateway(): + pcluster_config_files = PClusterConfigFiles( + pcluster_config_file=pcluster_config_file, + instance_types_data_file=instance_types_data_file, + instance_type_mapping_file=instance_type_mapping_file, 
+ ) + dask_gateway_options = DaskGatewaySlurmConfig(pcluster_config_files=pcluster_config_files) + assert dask_gateway_options diff --git a/versioneer.py b/versioneer.py index 1e3753e..de97d90 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1,4 +1,3 @@ - # Version: 0.29 """The Versioneer - like a rocketeer, but for versions. @@ -367,11 +366,13 @@ def get_root() -> str: or os.path.exists(pyproject_toml) or os.path.exists(versioneer_py) ): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") + err = ( + "Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND')." + ) raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools @@ -384,8 +385,10 @@ def get_root() -> str: me_dir = os.path.normcase(os.path.splitext(my_path)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) if me_dir != vsr_dir and "VERSIONEER_PEP518" not in globals(): - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(my_path), versioneer_py)) + print( + "Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(my_path), versioneer_py) + ) except NameError: pass return root @@ -403,9 +406,9 @@ def get_config_from_root(root: str) -> VersioneerConfig: section: Union[Dict[str, Any], configparser.SectionProxy, None] = None if pyproject_toml.exists() and have_tomllib: try: - with open(pyproject_toml, 'rb') as fobj: + with open(pyproject_toml, "rb") as fobj: pp = tomllib.load(fobj) - section = pp['tool']['versioneer'] + section = pp["tool"]["versioneer"] except (tomllib.TOMLDecodeError, KeyError) as e: print(f"Failed to load config from {pyproject_toml}: {e}") print("Try to load it from setup.cfg") @@ -422,7 +425,7 @@ def get_config_from_root(root: str) -> VersioneerConfig: # `None` values elsewhere where it matters cfg = VersioneerConfig() - cfg.VCS = section['VCS'] + cfg.VCS = section["VCS"] cfg.style = section.get("style", "") cfg.versionfile_source = cast(str, section.get("versionfile_source")) cfg.versionfile_build = section.get("versionfile_build") @@ -450,10 +453,12 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" HANDLERS.setdefault(vcs, {})[method] = f return f + return decorate @@ -480,10 +485,14 @@ def run_command( try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, + ) break except OSError as e: if e.errno == errno.ENOENT: @@ -505,7 +514,9 @@ def run_command( return stdout, process.returncode -LONG_VERSION_PY['git'] = r''' +LONG_VERSION_PY[ + "git" +] = r''' # This file 
helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -1250,7 +1261,7 @@ def git_versions_from_keywords( # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1259,7 +1270,7 @@ def git_versions_from_keywords( # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1267,32 +1278,36 @@ def git_versions_from_keywords( for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') - if not re.match(r'\d', r): + if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") def git_pieces_from_vcs( - tag_prefix: str, - root: str, - verbose: bool, - runner: Callable = run_command + tag_prefix: str, root: str, verbose: bool, runner: Callable = run_command ) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. 
@@ -1311,8 +1326,7 @@ def git_pieces_from_vcs( env.pop("GIT_DIR", None) runner = functools.partial(runner, env=env) - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1320,10 +1334,19 @@ def git_pieces_from_vcs( # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + f"{tag_prefix}[[:digit:]]*", + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -1338,8 +1361,7 @@ def git_pieces_from_vcs( pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") @@ -1379,17 +1401,16 @@ def git_pieces_from_vcs( dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? 
- pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -1398,10 +1419,12 @@ def git_pieces_from_vcs( if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -1479,15 +1502,21 @@ def versions_from_parentdir( for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -1516,11 +1545,13 @@ def versions_from_file(filename: str) -> Dict[str, Any]: contents = f.read() except OSError: raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) @@ -1528,8 +1559,7 @@ def versions_from_file(filename: str) -> Dict[str, Any]: def write_to_version_file(filename: str, versions: Dict[str, Any]) -> None: """Write the given version number to the given _version.py file.""" - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) + contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) @@ -1561,8 +1591,7 @@ def render_pep440(pieces: Dict[str, Any]) -> str: rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -1591,8 +1620,7 @@ def render_pep440_branch(pieces: Dict[str, Any]) -> str: rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -1753,11 +1781,13 @@ def render_git_describe_long(pieces: Dict[str, Any]) -> str: def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - 
"full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -1781,9 +1811,13 @@ def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } class VersioneerBadRootError(Exception): @@ -1806,8 +1840,9 @@ def get_versions(verbose: bool = False) -> Dict[str, Any]: handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or bool(cfg.verbose) # `bool()` used to avoid `None` - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" + assert ( + cfg.versionfile_source is not None + ), "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) @@ -1861,9 +1896,13 @@ def get_versions(verbose: bool = False) -> Dict[str, Any]: if verbose: print("unable to compute version") - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } def get_version() -> str: @@ -1916,6 +1955,7 @@ def run(self) -> None: print(" date: %s" % vers.get("date")) if vers["error"]: print(" error: %s" % vers["error"]) + cmds["version"] = cmd_version # we override "build_py" in setuptools @@ -1937,8 +1977,8 @@ def run(self) -> None: # but the build_py command is not expected to copy any files. # we override different "build_py" commands for both environments - if 'build_py' in cmds: - _build_py: Any = cmds['build_py'] + if "build_py" in cmds: + _build_py: Any = cmds["build_py"] else: from setuptools.command.build_py import build_py as _build_py @@ -1955,14 +1995,14 @@ def run(self) -> None: # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) + cmds["build_py"] = cmd_build_py - if 'build_ext' in cmds: - _build_ext: Any = cmds['build_ext'] + if "build_ext" in cmds: + _build_ext: Any = cmds["build_ext"] else: from setuptools.command.build_ext import build_ext as _build_ext @@ -1982,19 +2022,22 @@ def run(self) -> None: # it with an updated value if not cfg.versionfile_build: return - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) if not os.path.exists(target_versionfile): - print(f"Warning: {target_versionfile} does not exist, skipping " - "version update. This can happen if you are running build_ext " - "without first running build_py.") + print( + f"Warning: {target_versionfile} does not exist, skipping " + "version update. 
This can happen if you are running build_ext " + "without first running build_py." + ) return print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) + cmds["build_ext"] = cmd_build_ext if "cx_Freeze" in sys.modules: # cx_freeze enabled? from cx_Freeze.dist import build_exe as _build_exe # type: ignore + # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ @@ -2015,17 +2058,21 @@ def run(self) -> None: os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["build_exe"] = cmd_build_exe del cmds["build_py"] - if 'py2exe' in sys.modules: # py2exe enabled? + if "py2exe" in sys.modules: # py2exe enabled? try: from py2exe.setuptools_buildexe import py2exe as _py2exe # type: ignore except ImportError: @@ -2044,18 +2091,22 @@ def run(self) -> None: os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["py2exe"] = cmd_py2exe # sdist farms its file list building out to egg_info - if 'egg_info' in cmds: - _egg_info: Any = cmds['egg_info'] + if "egg_info" in cmds: + _egg_info: Any = cmds["egg_info"] else: from setuptools.command.egg_info import egg_info as _egg_info @@ -2068,7 +2119,7 @@ def find_sources(self) -> None: # Modify the filelist and normalize it root = get_root() cfg = get_config_from_root(root) - self.filelist.append('versioneer.py') + self.filelist.append("versioneer.py") if cfg.versionfile_source: # There are rare cases where versionfile_source might not be # included by default, so we must be explicit @@ -2081,18 +2132,21 @@ def find_sources(self) -> None: # We will instead replicate their final normalization (to unicode, # and POSIX-style paths) from setuptools import unicode_utils - normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/') - for f in self.filelist.files] - manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt') - with open(manifest_filename, 'w') as fobj: - fobj.write('\n'.join(normalized)) + normalized = [ + unicode_utils.filesys_decode(f).replace(os.sep, "/") + for f in self.filelist.files + ] + + manifest_filename = os.path.join(self.egg_info, "SOURCES.txt") + with open(manifest_filename, "w") as fobj: + fobj.write("\n".join(normalized)) - cmds['egg_info'] = cmd_egg_info + cmds["egg_info"] = cmd_egg_info # we override different "sdist" commands for both environments - if 'sdist' in cmds: - _sdist: Any = cmds['sdist'] + if "sdist" in cmds: + _sdist: Any = cmds["sdist"] else: from setuptools.command.sdist import sdist as _sdist @@ -2114,8 +2168,10 @@ def make_release_tree(self, base_dir: str, files: List[str]) -> None: # updated value target_versionfile = 
os.path.join(base_dir, cfg.versionfile_source) print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) + write_to_version_file( + target_versionfile, self._versioneer_generated_versions + ) + cmds["sdist"] = cmd_sdist return cmds @@ -2175,11 +2231,9 @@ def do_setup() -> int: root = get_root() try: cfg = get_config_from_root(root) - except (OSError, configparser.NoSectionError, - configparser.NoOptionError) as e: + except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e: if isinstance(e, (OSError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) + print("Adding sample versioneer config to setup.cfg", file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: f.write(SAMPLE_CONFIG) print(CONFIG_ERROR, file=sys.stderr) @@ -2188,15 +2242,18 @@ def do_setup() -> int: print(" creating %s" % cfg.versionfile_source) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") maybe_ipy: Optional[str] = ipy if os.path.exists(ipy): try: