diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index db60bc5c8..a389706dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ exclude: ^docs repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.4 + rev: v0.3.5 hooks: - id: ruff args: [--fix, --ignore, D] diff --git a/docs_rst/conf.py b/docs_rst/conf.py index e24724158..75895e36f 100644 --- a/docs_rst/conf.py +++ b/docs_rst/conf.py @@ -309,5 +309,5 @@ def skip(app, what, name, obj, skip, options): # AJ: a hack found online to get __init__ to show up in docs -def setup(app): +def setup(app) -> None: app.connect("autodoc-skip-member", skip) diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index 853356d11..e95015e82 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -16,7 +16,7 @@ from collections import defaultdict from copy import deepcopy from datetime import datetime -from typing import Any, Iterator, Sequence +from typing import Any, Iterator, NoReturn, Sequence from monty.io import reverse_readline, zopen from monty.os.path import zpath @@ -49,7 +49,7 @@ class FiretaskBase(defaultdict, FWSerializable, abc.ABC): # if set to a list of str, only required and optional kwargs are allowed; consistency checked upon init optional_params = None - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: dict.__init__(self, *args, **kwargs) required_params = self.required_params or [] @@ -68,7 +68,7 @@ def __init__(self, *args, **kwargs): ) @abc.abstractmethod - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> NoReturn: """ This method gets called when the Firetask is run. It can take in a Firework spec, perform some task using that data, and then return an @@ -101,7 +101,7 @@ def to_dict(self): def from_dict(cls, m_dict): return cls(m_dict) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.fw_name}>:{dict(self)}" # not strictly needed here for pickle/unpickle, but complements __setstate__ @@ -136,7 +136,7 @@ def __init__( defuse_children=False, defuse_workflow=False, propagate=False, - ): + ) -> None: """ Args: stored_data (dict): data to store from the run. Does not affect the operation of FireWorks. @@ -209,7 +209,7 @@ def skip_remaining_tasks(self): """ return self.exit or self.detours or self.additions or self.defuse_children or self.defuse_workflow - def __str__(self): + def __str__(self) -> str: return "FWAction\n" + pprint.pformat(self.to_dict()) @@ -241,7 +241,7 @@ def __init__( fw_id=None, parents=None, updated_on=None, - ): + ) -> None: """ Args: tasks (Firetask or [Firetask]): a list of Firetasks to run in sequence. @@ -286,7 +286,7 @@ def state(self): return self._state @state.setter - def state(self, state): + def state(self, state) -> None: """ Setter for the FW state, which triggers updated_on change. @@ -318,7 +318,7 @@ def to_dict(self): return m_dict - def _rerun(self): + def _rerun(self) -> None: """ Moves all Launches to archived Launches and resets the state to 'WAITING'. The Firework can thus be re-run even if it was Launched in the past. 
This method should be called by @@ -367,7 +367,7 @@ def from_dict(cls, m_dict): tasks, m_dict["spec"], name, launches, archived_launches, state, created_on, fw_id, updated_on=updated_on ) - def __str__(self): + def __str__(self) -> str: return f"Firework object: (id: {int(self.fw_id)} , name: {self.fw_name})" def __iter__(self) -> Iterator[FiretaskBase]: @@ -385,7 +385,7 @@ class Tracker(FWSerializable): MAX_TRACKER_LINES = 1000 - def __init__(self, filename, nlines=TRACKER_LINES, content="", allow_zipped=False): + def __init__(self, filename, nlines=TRACKER_LINES, content="", allow_zipped=False) -> None: """ Args: filename (str) @@ -437,7 +437,7 @@ def from_dict(cls, m_dict): m_dict["filename"], m_dict["nlines"], m_dict.get("content", ""), m_dict.get("allow_zipped", False) ) - def __str__(self): + def __str__(self) -> str: return f"### Filename: {self.filename}\n{self.content}" @@ -456,7 +456,7 @@ def __init__( state_history=None, launch_id=None, fw_id=None, - ): + ) -> None: """ Args: state (str): the state of the Launch (e.g. RUNNING, COMPLETED) @@ -483,7 +483,7 @@ def __init__( self.launch_id = launch_id self.fw_id = fw_id - def touch_history(self, update_time=None, checkpoint=None): + def touch_history(self, update_time=None, checkpoint=None) -> None: """ Updates the update_on field of the state history of a Launch. Used to ping that a Launch is still alive. @@ -496,7 +496,7 @@ def touch_history(self, update_time=None, checkpoint=None): self.state_history[-1]["checkpoint"] = checkpoint self.state_history[-1]["updated_on"] = update_time - def set_reservation_id(self, reservation_id): + def set_reservation_id(self, reservation_id) -> None: """ Adds the job_id to the reservation. @@ -517,7 +517,7 @@ def state(self): return self._state @state.setter - def state(self, state): + def state(self, state) -> None: """ Setter for the Launch's state. Automatically triggers an update to state_history. @@ -627,7 +627,7 @@ def from_dict(cls, m_dict): m_dict["fw_id"], ) - def _update_state_history(self, state): + def _update_state_history(self, state) -> None: """ Internal method to update the state history whenever the Launch state is modified. @@ -675,7 +675,7 @@ class Workflow(FWSerializable): class Links(dict, FWSerializable): """An inner class for storing the DAG links between FireWorks.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) for k, v in list(self.items()): @@ -906,7 +906,7 @@ def apply_action(self, action: FWAction, fw_id: int) -> list[int]: # Traverse whole sub-workflow down to leaves. 
visited_cfid = set() # avoid double-updating for diamond deps - def recursive_update_spec(fw_id): + def recursive_update_spec(fw_id) -> None: for cfid in self.links[fw_id]: if cfid not in visited_cfid: visited_cfid.add(cfid) @@ -926,7 +926,7 @@ def recursive_update_spec(fw_id): if action.mod_spec and action.propagate: visited_cfid = set() - def recursive_mod_spec(fw_id): + def recursive_mod_spec(fw_id) -> None: for cfid in self.links[fw_id]: if cfid not in visited_cfid: visited_cfid.add(cfid) @@ -1349,10 +1349,10 @@ def from_Firework(cls, fw: Firework, name: str | None = None, metadata=None) -> name = name if name else fw.name return Workflow([fw], None, name=name, metadata=metadata, created_on=fw.created_on, updated_on=fw.updated_on) - def __str__(self): + def __str__(self) -> str: return f"Workflow object: (fw_ids: {[*self.id_fw]} , name: {self.name})" - def remove_fws(self, fw_ids): + def remove_fws(self, fw_ids) -> None: """ Remove the fireworks corresponding to the input firework ids and update the workflow i.e the parents of the removed fireworks become the parents of the children fireworks (only if the diff --git a/fireworks/core/fworker.py b/fireworks/core/fworker.py index 559e140ed..08c0395ac 100644 --- a/fireworks/core/fworker.py +++ b/fireworks/core/fworker.py @@ -19,7 +19,7 @@ class FWorker(FWSerializable): - def __init__(self, name="Automatically generated Worker", category="", query=None, env=None): + def __init__(self, name="Automatically generated Worker", category="", query=None, env=None) -> None: """ Args: name (str): the name of the resource, should be unique diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index 885ae53f4..b30896249 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -91,7 +91,7 @@ class WFLock: Calling functions are responsible for handling the error in order to avoid database inconsistencies. """ - def __init__(self, lp, fw_id, expire_secs=WFLOCK_EXPIRATION_SECS, kill=WFLOCK_EXPIRATION_KILL): + def __init__(self, lp, fw_id, expire_secs=WFLOCK_EXPIRATION_SECS, kill=WFLOCK_EXPIRATION_KILL) -> None: """ Args: lp (LaunchPad) @@ -155,7 +155,7 @@ def __init__( authsource=None, uri_mode=False, mongoclient_kwargs=None, - ): + ) -> None: """ Args: host (str): hostname. If uri_mode is True, a MongoDB connection string URI @@ -243,7 +243,7 @@ def to_dict(self): "mongoclient_kwargs": self.mongoclient_kwargs, } - def update_spec(self, fw_ids, spec_document, mongo=False): + def update_spec(self, fw_ids, spec_document, mongo=False) -> None: """ Update fireworks with a spec. Sometimes you need to modify a firework in progress. @@ -300,7 +300,7 @@ def auto_load(cls): return LaunchPad.from_file(LAUNCHPAD_LOC) return LaunchPad() - def reset(self, password, require_password=True, max_reset_wo_password=25): + def reset(self, password, require_password=True, max_reset_wo_password=25) -> None: """ Create a new FireWorks database. This will overwrite the existing FireWorks database! To safeguard against accidentally erasing an existing database, a password must be entered. @@ -336,7 +336,7 @@ def reset(self, password, require_password=True, max_reset_wo_password=25): else: raise ValueError(f"Invalid password! Password is today's date: {m_password}") - def maintain(self, infinite=True, maintain_interval=None): + def maintain(self, infinite=True, maintain_interval=None) -> None: """ Perform launchpad maintenance: detect lost runs and unreserved RESERVE launches. 
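
A minimal usage sketch for the LaunchPad.update_spec call annotated in the hunk above, assuming a configured LaunchPad and an existing firework; the id and spec key below are made up for illustration, and the behavior assumed is the usual one of writing each key into the firework's spec:

    from fireworks import LaunchPad

    lp = LaunchPad.auto_load()  # loads LAUNCHPAD_LOC if one is configured
    # Write the given keys into the spec of the listed fireworks; the method
    # mutates the database and returns nothing, hence the -> None annotation.
    lp.update_spec([1234], {"_category": "high_mem"})
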
@@ -397,7 +397,7 @@ def add_wf(self, wf, reassign_all=True): self.m_logger.info(f"Added a workflow. id_map: {old_new}") return old_new - def bulk_add_wfs(self, wfs): + def bulk_add_wfs(self, wfs) -> None: """ Adds a list of workflows to the fireworks database using insert_many for both the fws and wfs, is @@ -438,7 +438,7 @@ def bulk_add_wfs(self, wfs): self.fireworks.insert_many(fw.to_db_dict() for fw in all_fws) return - def append_wf(self, new_wf, fw_ids, detour=False, pull_spec_mods=True): + def append_wf(self, new_wf, fw_ids, detour=False, pull_spec_mods=True) -> None: """ Append a new workflow on top of an existing workflow. @@ -563,7 +563,7 @@ def get_wf_by_fw_id_lzyfw(self, fw_id: int) -> Workflow: fw_states, ) - def delete_fws(self, fw_ids, delete_launch_dirs=False): + def delete_fws(self, fw_ids, delete_launch_dirs=False) -> None: """Delete a set of fireworks identified by their fw_ids. ATTENTION: This function serves maintenance purposes and will leave @@ -611,7 +611,7 @@ def delete_fws(self, fw_ids, delete_launch_dirs=False): self.offline_runs.delete_many({"launch_id": {"$in": launch_ids}}) self.fireworks.delete_many({"fw_id": {"$in": fw_ids}}) - def delete_wf(self, fw_id, delete_launch_dirs=False): + def delete_wf(self, fw_id, delete_launch_dirs=False) -> None: """ Delete the workflow containing firework with the given id. @@ -885,7 +885,7 @@ def run_exists(self, fworker=None): q = fworker.query if fworker else {} return bool(self._get_a_fw_to_run(query=q, checkout=False)) - def future_run_exists(self, fworker=None): + def future_run_exists(self, fworker=None) -> bool: """Check if database has any current OR future Fireworks available. Returns: @@ -908,7 +908,7 @@ def future_run_exists(self, fworker=None): # there is no future work to do return False - def tuneup(self, bkground=True): + def tuneup(self, bkground=True) -> None: """Database tuneup: build indexes.""" self.m_logger.info("Performing db tune-up") @@ -1029,7 +1029,7 @@ def resume_fw(self, fw_id): self._refresh_wf(fw_id) return f - def defuse_wf(self, fw_id, defuse_all_states=True): + def defuse_wf(self, fw_id, defuse_all_states=True) -> None: """ Defuse the workflow containing the given firework id. @@ -1042,7 +1042,7 @@ def defuse_wf(self, fw_id, defuse_all_states=True): if fw.state not in ["COMPLETED", "FIZZLED"] or defuse_all_states: self.defuse_fw(fw.fw_id) - def pause_wf(self, fw_id): + def pause_wf(self, fw_id) -> None: """ Pause the workflow containing the given firework id. @@ -1054,7 +1054,7 @@ def pause_wf(self, fw_id): if fw.state not in ["COMPLETED", "FIZZLED", "DEFUSED"]: self.pause_fw(fw.fw_id) - def reignite_wf(self, fw_id): + def reignite_wf(self, fw_id) -> None: """ Reignite the workflow containing the given firework id. @@ -1065,7 +1065,7 @@ def reignite_wf(self, fw_id): for fw in wf: self.reignite_fw(fw.fw_id) - def archive_wf(self, fw_id): + def archive_wf(self, fw_id) -> None: """ Archive the workflow containing the given firework id. @@ -1087,7 +1087,7 @@ def archive_wf(self, fw_id): ) self._refresh_wf(fw.fw_id) - def _restart_ids(self, next_fw_id, next_launch_id): + def _restart_ids(self, next_fw_id, next_launch_id) -> None: """ internal method used to reset firework id counters. @@ -1101,7 +1101,7 @@ def _restart_ids(self, next_fw_id, next_launch_id): ) self.m_logger.debug(f"RESTARTED fw_id, launch_id to ({next_fw_id}, {next_launch_id})") - def _check_fw_for_uniqueness(self, m_fw): + def _check_fw_for_uniqueness(self, m_fw) -> bool: """ Check if there are duplicates. 
If not unique, a new id is assigned and the workflow refreshed. @@ -1202,7 +1202,7 @@ def get_fw_ids_from_reservation_id(self, reservation_id): l_id = self.launches.find_one({"state_history.reservation_id": reservation_id}, {"launch_id": 1})["launch_id"] return [fw["fw_id"] for fw in self.fireworks.find({"launches": l_id}, {"fw_id": 1})] - def cancel_reservation_by_reservation_id(self, reservation_id): + def cancel_reservation_by_reservation_id(self, reservation_id) -> None: """Given the reservation id, cancel the reservation and rerun the corresponding fireworks.""" l_id = self.launches.find_one( {"state_history.reservation_id": reservation_id, "state": "RESERVED"}, {"launch_id": 1} @@ -1223,7 +1223,7 @@ def get_reservation_id_from_fw_id(self, fw_id): return None return None - def cancel_reservation(self, launch_id): + def cancel_reservation(self, launch_id) -> None: """Given the launch id, cancel the reservation and rerun the fireworks.""" m_launch = self.get_launch_by_id(launch_id) m_launch.state = "READY" @@ -1264,7 +1264,7 @@ def detect_unreserved(self, expiration_secs=RESERVATION_EXPIRATION_SECS, rerun=F self.cancel_reservation(lid) return bad_launch_ids - def mark_fizzled(self, launch_id): + def mark_fizzled(self, launch_id) -> None: """ Mark the launch corresponding to the given id as FIZZLED. @@ -1381,7 +1381,7 @@ def detect_lostruns( return lost_launch_ids, lost_fw_ids, inconsistent_fw_ids - def set_reservation_id(self, launch_id, reservation_id): + def set_reservation_id(self, launch_id, reservation_id) -> None: """ Set reservation id to the launch corresponding to the given launch id. @@ -1471,7 +1471,7 @@ def checkout_fw(self, fworker, launch_dir, fw_id=None, host=None, ip=None, state return m_fw, launch_id - def change_launch_dir(self, launch_id, launch_dir): + def change_launch_dir(self, launch_id, launch_dir) -> None: """ Change the launch directory corresponding to the given launch id. @@ -1483,7 +1483,7 @@ def change_launch_dir(self, launch_id, launch_dir): m_launch.launch_dir = launch_dir self.launches.find_one_and_replace({"launch_id": m_launch.launch_id}, m_launch.to_db_dict(), upsert=True) - def restore_backup_data(self, launch_id, fw_id): + def restore_backup_data(self, launch_id, fw_id) -> None: """For the given launch id and firework id, restore the back up data.""" if launch_id in self.backup_launch_data: self.launches.find_one_and_replace({"launch_id": launch_id}, self.backup_launch_data[launch_id]) @@ -1541,7 +1541,7 @@ def complete_launch(self, launch_id, action=None, state="COMPLETED"): # change return type to dict to make return type serializable to support job packing return m_launch.to_dict() - def ping_launch(self, launch_id, ptime=None, checkpoint=None): + def ping_launch(self, launch_id, ptime=None, checkpoint=None) -> None: """ Ping that a Launch is still alive: updates the 'update_on 'field of the state history of a Launch. @@ -1715,7 +1715,7 @@ def get_recovery(self, fw_id, launch_id="last"): recovery.update({"_prev_dir": launch.launch_dir, "_launch_id": launch.launch_id}) return recovery - def _refresh_wf(self, fw_id): + def _refresh_wf(self, fw_id) -> None: """ Update the FW state of all jobs in workflow. @@ -1743,7 +1743,7 @@ def _refresh_wf(self, fw_id): err_message = f"Error refreshing workflow. The full stack trace is: {traceback.format_exc()}" raise RuntimeError(err_message) - def _update_wf(self, wf, updated_ids): + def _update_wf(self, wf, updated_ids) -> None: """ Update the workflow with the updated firework ids. 
Note: must be called within an enclosing WFLock. @@ -1821,7 +1821,7 @@ def _steal_launches(self, thief_fw): self.m_logger.info(f"Duplicate found! fwids {thief_fw.fw_id} and {potential_match['fw_id']}") return stolen - def set_priority(self, fw_id, priority): + def set_priority(self, fw_id, priority) -> None: """ Set priority to the firework with the given id. @@ -1839,7 +1839,7 @@ def get_logdir(self): """ return self.logdir - def add_offline_run(self, launch_id, fw_id, name): + def add_offline_run(self, launch_id, fw_id, name) -> None: """ Add the launch and firework to the offline_run collection. @@ -1954,7 +1954,7 @@ def recover_offline(self, launch_id, ignore_errors=False, print_errors=False): self.offline_runs.update_one({"launch_id": launch_id}, {"$set": {"completed": True}}) return m_launch.fw_id - def forget_offline(self, launchid_or_fwid, launch_mode=True): + def forget_offline(self, launchid_or_fwid, launch_mode=True) -> None: """ Unmark the offline run for the given launch or firework id. @@ -1990,7 +1990,7 @@ def get_launchdir(self, fw_id, launch_idx=-1): fw = self.get_fw_by_id(fw_id) return fw.launches[launch_idx].launch_dir if len(fw.launches) > 0 else None - def log_message(self, level, message): + def log_message(self, level, message) -> None: """ Support for job packing. @@ -2012,7 +2012,7 @@ class LazyFirework: db_fields = ("name", "fw_id", "spec", "created_on", "state") db_launch_fields = ("launches", "archived_launches") - def __init__(self, fw_id, fw_coll, launch_coll, fallback_fs): + def __init__(self, fw_id, fw_coll, launch_coll, fallback_fs) -> None: """ Args: fw_id (int): firework id @@ -2038,20 +2038,20 @@ def state(self): return self._state @state.setter - def state(self, state): + def state(self, state) -> None: self.partial_fw._state = state self.partial_fw.updated_on = datetime.datetime.utcnow() def to_dict(self): return self.full_fw.to_dict() - def _rerun(self): + def _rerun(self) -> None: self.full_fw._rerun() def to_db_dict(self): return self.full_fw.to_db_dict() - def __str__(self): + def __str__(self) -> str: return f"LazyFireWork object: (id: {self.fw_id})" # Properties that shadow FireWork attributes @@ -2061,7 +2061,7 @@ def tasks(self): return self.partial_fw.tasks @tasks.setter - def tasks(self, value): + def tasks(self, value) -> None: self.partial_fw.tasks = value @property @@ -2069,7 +2069,7 @@ def spec(self): return self.partial_fw.spec @spec.setter - def spec(self, value): + def spec(self, value) -> None: self.partial_fw.spec = value @property @@ -2077,7 +2077,7 @@ def name(self): return self.partial_fw.name @name.setter - def name(self, value): + def name(self, value) -> None: self.partial_fw.name = value @property @@ -2085,7 +2085,7 @@ def created_on(self): return self.partial_fw.created_on @created_on.setter - def created_on(self, value): + def created_on(self, value) -> None: self.partial_fw.created_on = value @property @@ -2093,7 +2093,7 @@ def updated_on(self): return self.partial_fw.updated_on @updated_on.setter - def updated_on(self, value): + def updated_on(self, value) -> None: self.partial_fw.updated_on = value @property @@ -2103,7 +2103,7 @@ def parents(self): return [] @parents.setter - def parents(self, value): + def parents(self, value) -> None: self.partial_fw.parents = value # Properties that shadow FireWork attributes, but which are @@ -2114,7 +2114,7 @@ def launches(self): return self._get_launch_data("launches") @launches.setter - def launches(self, value): + def launches(self, value) -> None: self._launches["launches"] = 
True self.partial_fw.launches = value @@ -2123,7 +2123,7 @@ def archived_launches(self): return self._get_launch_data("archived_launches") @archived_launches.setter - def archived_launches(self, value): + def archived_launches(self, value) -> None: self._launches["archived_launches"] = True self.partial_fw.archived_launches = value diff --git a/fireworks/core/rocket.py b/fireworks/core/rocket.py index 9720200e4..a6edd4fcd 100644 --- a/fireworks/core/rocket.py +++ b/fireworks/core/rocket.py @@ -73,7 +73,7 @@ def start_ping_launch(launchpad: LaunchPad, launch_id: int) -> Event | None: return ping_stop -def stop_backgrounds(ping_stop, btask_stops): +def stop_backgrounds(ping_stop, btask_stops) -> None: fd = FWData() if fd.MULTIPROCESSING: fd.Running_IDs[os.getpid()] = None @@ -84,7 +84,7 @@ def stop_backgrounds(ping_stop, btask_stops): b.set() -def background_task(btask, spec, stop_event, master_thread): +def background_task(btask, spec, stop_event, master_thread) -> None: num_launched = 0 while not stop_event.is_set() and master_thread.is_alive(): for task in btask.tasks: diff --git a/fireworks/core/rocket_launcher.py b/fireworks/core/rocket_launcher.py index 055b61a93..15ac7ebb1 100644 --- a/fireworks/core/rocket_launcher.py +++ b/fireworks/core/rocket_launcher.py @@ -68,7 +68,7 @@ def rapidfire( timeout: int | None = None, local_redirect: bool = False, pdb_on_exception: bool = False, -): +) -> None: """ Keeps running Rockets in m_dir until we reach an error. Automatically creates subdirectories for each Rocket. Usually stops when we run out of FireWorks from the LaunchPad. diff --git a/fireworks/core/tests/tasks.py b/fireworks/core/tests/tasks.py index 83aa0bf9f..f13b5ddef 100644 --- a/fireworks/core/tests/tasks.py +++ b/fireworks/core/tests/tasks.py @@ -1,4 +1,5 @@ import time +from typing import NoReturn from unittest import SkipTest from fireworks import FiretaskBase, Firework, FWAction @@ -6,7 +7,7 @@ class SerializableException(Exception): - def __init__(self, exc_details): + def __init__(self, exc_details) -> None: self.exc_details = exc_details def to_dict(self): @@ -17,7 +18,7 @@ def to_dict(self): class ExceptionTestTask(FiretaskBase): exec_counter = 0 - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: ExceptionTestTask.exec_counter += 1 if not fw_spec.get("skip_exception", False): raise SerializableException(self["exc_details"]) @@ -27,7 +28,7 @@ def run_task(self, fw_spec): class ExecutionCounterTask(FiretaskBase): exec_counter = 0 - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: ExecutionCounterTask.exec_counter += 1 @@ -39,7 +40,7 @@ def run_task(self, fw_spec): @explicit_serialize class TodictErrorTask(FiretaskBase): - def to_dict(self): + def to_dict(self) -> NoReturn: raise RuntimeError("to_dict error") def run_task(self, fw_spec): @@ -87,7 +88,7 @@ def run_task(self, fw_spec): @explicit_serialize class DoNothingTask(FiretaskBase): - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: pass diff --git a/fireworks/core/tests/test_firework.py b/fireworks/core/tests/test_firework.py index 093f24313..7d77ef0b1 100644 --- a/fireworks/core/tests/test_firework.py +++ b/fireworks/core/tests/test_firework.py @@ -16,7 +16,7 @@ class FiretaskBaseTest(unittest.TestCase): - def test_init(self): + def test_init(self) -> None: class DummyTask(FiretaskBase): required_params = ["hello"] @@ -37,7 +37,7 @@ class DummyTask2(FiretaskBase): with pytest.raises(NotImplementedError): d.run_task({}) - def test_param_checks(self): + def 
test_param_checks(self) -> None: class DummyTask(FiretaskBase): _fw_name = "DummyTask" required_params = ["param1"] @@ -59,14 +59,14 @@ def run_task(self, fw_spec): class FiretaskPickleTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: import pickle self.task = PickleTask(test=0) self.pkl_task = pickle.dumps(self.task) self.upkl_task = pickle.loads(self.pkl_task) - def test_init(self): + def test_init(self) -> None: assert isinstance(self.upkl_task, PickleTask) assert PickleTask.from_dict(self.task.to_dict()) == self.upkl_task assert dir(self.task) == dir(self.upkl_task) @@ -91,12 +91,12 @@ def run_task(self, fw_spec): class WorkflowTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.fw1 = Firework(Task1()) self.fw2 = Firework([Task2(), Task2()], parents=self.fw1) self.fw3 = Firework(Task1(), parents=self.fw1) - def test_init(self): + def test_init(self) -> None: fws = [] for i in range(5): fw = Firework([PyTask(func="print", args=[i])], fw_id=i) @@ -109,7 +109,7 @@ def test_init(self): with pytest.raises(ValueError): Workflow(fws, links_dict={0: [1, 2, 3], 1: [4], 2: [100]}) - def test_copy(self): + def test_copy(self) -> None: """Test that we can produce a copy of a Workflow but that the copy has unique fw_ids. """ @@ -134,7 +134,7 @@ def test_copy(self): for child_id, orig_child_id in zip(children, orig_children): assert orig_child_id == wf_copy.id_fw[child_id].name - def test_remove_leaf_fws(self): + def test_remove_leaf_fws(self) -> None: fw4 = Firework(Task1(), parents=[self.fw2, self.fw3]) fws = [self.fw1, self.fw2, self.fw3, fw4] wflow = Workflow(fws) @@ -145,7 +145,7 @@ def test_remove_leaf_fws(self): wflow.remove_fws(wflow.leaf_fw_ids) assert wflow.leaf_fw_ids == parents - def test_remove_root_fws(self): + def test_remove_root_fws(self) -> None: fw4 = Firework(Task1(), parents=[self.fw2, self.fw3]) fws = [self.fw1, self.fw2, self.fw3, fw4] wflow = Workflow(fws) @@ -156,7 +156,7 @@ def test_remove_root_fws(self): wflow.remove_fws(wflow.root_fw_ids) assert sorted(wflow.root_fw_ids) == sorted(children) - def test_iter_len_index(self): + def test_iter_len_index(self) -> None: fws = [self.fw1, self.fw2, self.fw3] wflow = Workflow(fws) for idx, fw in enumerate(wflow): diff --git a/fireworks/core/tests/test_launchpad.py b/fireworks/core/tests/test_launchpad.py index d60cefcd6..31026f8e2 100644 --- a/fireworks/core/tests/test_launchpad.py +++ b/fireworks/core/tests/test_launchpad.py @@ -41,27 +41,27 @@ class AuthenticationTest(unittest.TestCase): """Tests whether users are authenticating against the correct mongo dbs.""" @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: try: client = MongoClient() client.not_the_admin_db.command("createUser", "myuser", pwd="mypassword", roles=["dbOwner"]) except Exception: raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") - def test_no_admin_privileges_for_plebs(self): + def test_no_admin_privileges_for_plebs(self) -> None: """Normal users can not authenticate against the admin db.""" with pytest.raises(OperationFailure): lp = LaunchPad(name="admin", username="myuser", password="mypassword", authsource="admin") lp.db.collection.count_documents({}) - def test_authenticating_to_users_db(self): + def test_authenticating_to_users_db(self) -> None: """A user should be able to authenticate against a database that they are a user of. 
""" lp = LaunchPad(name="not_the_admin_db", username="myuser", password="mypassword", authsource="not_the_admin_db") lp.db.collection.count_documents({}) - def test_authsource_infered_from_db_name(self): + def test_authsource_infered_from_db_name(self) -> None: """The default behavior is to authenticate against the db that the user is trying to access. """ @@ -71,7 +71,7 @@ def test_authsource_infered_from_db_name(self): class LaunchPadTest(unittest.TestCase): @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.lp = None cls.fworker = FWorker() try: @@ -81,17 +81,17 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TEST_DB_NAME) cls.lp.connection - def setUp(self): + def setUp(self) -> None: self.old_wd = os.getcwd() self.LP_LOC = os.path.join(MODULE_DIR, "launchpad.yaml") self.lp.to_file(self.LP_LOC) - def tearDown(self): + def tearDown(self) -> None: self.lp.reset(password=None, require_password=False, max_reset_wo_password=1000) # Delete launch locations if os.path.exists(os.path.join("FW.json")): @@ -103,13 +103,13 @@ def tearDown(self): if os.path.exists(self.LP_LOC): os.remove(self.LP_LOC) - def test_dict_from_file(self): + def test_dict_from_file(self) -> None: lp = LaunchPad.from_file(self.LP_LOC) lp_dict = lp.to_dict() new_lp = LaunchPad.from_dict(lp_dict) assert isinstance(new_lp, LaunchPad) - def test_reset(self): + def test_reset(self) -> None: # Store some test fireworks # Attempt couple of ways to reset the lp and check fw = Firework(ScriptTask.from_str('echo "hello"'), name="hello") @@ -130,14 +130,14 @@ def test_reset(self): self.lp.reset("") self.lp.reset("", False, 100) # reset back - def test_pw_check(self): + def test_pw_check(self) -> None: fw = Firework(ScriptTask.from_str('echo "hello"'), name="hello") self.lp.add_wf(fw) args = ("",) with pytest.raises(ValueError): self.lp.reset(*args) - def test_add_wf(self): + def test_add_wf(self) -> None: fw = Firework(ScriptTask.from_str('echo "hello"'), name="hello") self.lp.add_wf(fw) wf_id = self.lp.get_wf_ids() @@ -154,7 +154,7 @@ def test_add_wf(self): assert len(fw_ids) == 3 self.lp.reset("", require_password=False) - def test_add_wfs(self): + def test_add_wfs(self) -> None: ftask = ScriptTask.from_str('echo "lorem ipsum"') wfs = [] for _ in range(50): @@ -172,7 +172,7 @@ def test_add_wfs(self): class LaunchPadDefuseReigniteRerunArchiveDeleteTest(unittest.TestCase): @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.lp = None cls.fworker = FWorker() try: @@ -182,11 +182,11 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TEST_DB_NAME) - def setUp(self): + def setUp(self) -> None: # define the individual FireWorks used in the Workflow # Parent Firework fw_p = Firework( @@ -291,7 +291,7 @@ def setUp(self): self.old_wd = os.getcwd() - def tearDown(self): + def tearDown(self) -> None: self.lp.reset(password=None, require_password=False) # Delete launch locations if os.path.exists(os.path.join("FW.json")): @@ -301,12 +301,12 @@ def tearDown(self): shutil.rmtree(ldir) @staticmethod - def _teardown(dests): + def _teardown(dests) -> None: for f in dests: if os.path.exists(f): os.remove(f) - def test_pause_fw(self): + def test_pause_fw(self) -> None: self.lp.pause_fw(self.zeus_fw_id) paused_ids = self.lp.get_fw_ids({"state": "PAUSED"}) @@ -339,7 +339,7 @@ def test_pause_fw(self): except Exception: raise - def test_defuse_fw(self): + def test_defuse_fw(self) -> None: # defuse Zeus self.lp.defuse_fw(self.zeus_fw_id) @@ -363,7 +363,7 @@ def test_defuse_fw(self): except Exception: raise - def test_defuse_fw_after_completion(self): + def test_defuse_fw_after_completion(self) -> None: # Launch rockets in rapidfire rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR) # defuse Zeus @@ -374,7 +374,7 @@ def test_defuse_fw_after_completion(self): completed_ids = set(self.lp.get_fw_ids({"state": "COMPLETED"})) assert not self.zeus_child_fw_ids.issubset(completed_ids) - def test_reignite_fw(self): + def test_reignite_fw(self) -> None: # Defuse Zeus self.lp.defuse_fw(self.zeus_fw_id) defused_ids = self.lp.get_fw_ids({"state": "DEFUSED"}) @@ -392,7 +392,7 @@ def test_reignite_fw(self): assert self.zeus_fw_id in completed_ids assert self.zeus_child_fw_ids.issubset(completed_ids) - def test_pause_wf(self): + def test_pause_wf(self) -> None: # pause Workflow containing Zeus self.lp.pause_wf(self.zeus_fw_id) paused_ids = self.lp.get_fw_ids({"state": "PAUSED"}) @@ -405,7 +405,7 @@ def test_pause_wf(self): fws_no_run = set(self.lp.get_fw_ids({"state": {"$nin": ["COMPLETED"]}})) assert fws_no_run == self.all_ids - def test_defuse_wf(self): + def test_defuse_wf(self) -> None: # defuse Workflow containing Zeus self.lp.defuse_wf(self.zeus_fw_id) defused_ids = self.lp.get_fw_ids({"state": "DEFUSED"}) @@ -418,7 +418,7 @@ def test_defuse_wf(self): fws_no_run = set(self.lp.get_fw_ids({"state": {"$nin": ["COMPLETED"]}})) assert fws_no_run == self.all_ids - def test_defuse_wf_after_partial_run(self): + def test_defuse_wf_after_partial_run(self) -> None: # Run a firework before defusing Zeus launch_rocket(self.lp, self.fworker) print("----------\nafter launch rocket\n--------") @@ -440,7 +440,7 @@ def test_defuse_wf_after_partial_run(self): fws_no_run = set(self.lp.get_fw_ids({"state": "COMPLETED"})) assert len(fws_no_run) == 0 - def test_reignite_wf(self): + def test_reignite_wf(self) -> None: # Defuse workflow containing Zeus self.lp.defuse_wf(self.zeus_fw_id) defused_ids = self.lp.get_fw_ids({"state": "DEFUSED"}) @@ -457,7 +457,7 @@ def test_reignite_wf(self): fws_completed = set(self.lp.get_fw_ids({"state": "COMPLETED"})) assert fws_completed == self.all_ids - def test_archive_wf(self): + def test_archive_wf(self) -> None: # Run a firework before archiving Zeus launch_rocket(self.lp, self.fworker) @@ -475,7 +475,7 @@ def test_archive_wf(self): fw = self.lp.get_fw_by_id(self.zeus_fw_id) assert fw.state == "ARCHIVED" - def test_delete_wf(self): + def test_delete_wf(self) -> None: # Run a firework before deleting 
Zeus rapidfire(self.lp, self.fworker, nlaunches=1) @@ -497,7 +497,7 @@ def test_delete_wf(self): # Check that the launch dir has not been deleted assert os.path.isdir(first_ldir) - def test_delete_wf_and_files(self): + def test_delete_wf_and_files(self) -> None: # Run a firework before deleting Zeus rapidfire(self.lp, self.fworker, nlaunches=1) @@ -519,7 +519,7 @@ def test_delete_wf_and_files(self): # Check that the launch dir has not been deleted assert not os.path.isdir(first_ldir) - def test_rerun_fws2(self): + def test_rerun_fws2(self) -> None: # Launch all fireworks rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR) fw = self.lp.get_fw_by_id(self.zeus_fw_id) @@ -561,7 +561,7 @@ def test_rerun_fws2(self): @unittest.skipIf(PYMONGO_MAJOR_VERSION > 3, "detect lostruns test not supported for pymongo major version > 3") class LaunchPadLostRunsDetectTest(unittest.TestCase): @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.lp = None cls.fworker = FWorker() try: @@ -571,11 +571,11 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TEST_DB_NAME) - def setUp(self): + def setUp(self) -> None: # Define a timed fireWork fw_timer = Firework(PyTask(func="time.sleep", args=[5]), name="timer") self.lp.add_wf(fw_timer) @@ -585,7 +585,7 @@ def setUp(self): self.old_wd = os.getcwd() - def tearDown(self): + def tearDown(self) -> None: self.lp.reset(password=None, require_password=False) # Delete launch locations if os.path.exists(os.path.join("FW.json")): @@ -595,15 +595,15 @@ def tearDown(self): shutil.rmtree(ldir) # self.lp.connection.close() - def test_detect_lostruns(self): + def test_detect_lostruns(self) -> None: # Launch the timed firework in a separate process class RocketProcess(Process): - def __init__(self, lpad, fworker): + def __init__(self, lpad, fworker) -> None: super(self.__class__, self).__init__() self.lpad = lpad self.fworker = fworker - def run(self): + def run(self) -> None: launch_rocket(self.lpad, self.fworker) rp = RocketProcess(self.lp, self.fworker) @@ -634,15 +634,15 @@ def run(self): assert (lost_launch_ids, lost_fw_ids) == ([1], [1]) assert self.lp.get_fw_by_id(1).state == "READY" - def test_detect_lostruns_defuse(self): + def test_detect_lostruns_defuse(self) -> None: # Launch the timed firework in a separate process class RocketProcess(Process): - def __init__(self, lpad, fworker): + def __init__(self, lpad, fworker) -> None: super(self.__class__, self).__init__() self.lpad = lpad self.fworker = fworker - def run(self): + def run(self) -> None: launch_rocket(self.lpad, self.fworker) rp = RocketProcess(self.lp, self.fworker) @@ -666,15 +666,15 @@ def run(self): assert (lost_launch_ids, lost_fw_ids) == ([1], []) assert self.lp.get_fw_by_id(1).state == "DEFUSED" - def test_state_after_run_start(self): + def test_state_after_run_start(self) -> None: # Launch the timed firework in a separate process class RocketProcess(Process): - def __init__(self, lpad, fworker): + def __init__(self, lpad, fworker) -> None: super(self.__class__, self).__init__() self.lpad = lpad self.fworker = fworker - def run(self): + def run(self) -> None: launch_rocket(self.lpad, self.fworker) rp = RocketProcess(self.lp, self.fworker) @@ -703,7 +703,7 @@ class WorkflowFireworkStatesTest(unittest.TestCase): """ @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.lp = None cls.fworker = FWorker() 
try: @@ -713,11 +713,11 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TEST_DB_NAME) - def setUp(self): + def setUp(self) -> None: # define the individual FireWorks used in the Workflow # Parent Firework fw_p = Firework( @@ -821,7 +821,7 @@ def setUp(self): self.old_wd = os.getcwd() - def tearDown(self): + def tearDown(self) -> None: self.lp.reset(password=None, require_password=False) # Delete launch locations if os.path.exists(os.path.join("FW.json")): @@ -831,12 +831,12 @@ def tearDown(self): shutil.rmtree(ldir) @staticmethod - def _teardown(dests): + def _teardown(dests) -> None: for f in dests: if os.path.exists(f): os.remove(f) - def test_defuse_fw(self): + def test_defuse_fw(self) -> None: # defuse Zeus self.lp.defuse_fw(self.zeus_fw_id) # Ensure the states are sync after defusing fw @@ -860,7 +860,7 @@ def test_defuse_fw(self): except Exception: raise - def test_defuse_fw_after_completion(self): + def test_defuse_fw_after_completion(self) -> None: # Launch rockets in rapidfire rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR) # defuse Zeus @@ -873,7 +873,7 @@ def test_defuse_fw_after_completion(self): fw_cache_state = wf.fw_states[fw_id] assert fw_state == fw_cache_state - def test_reignite_fw(self): + def test_reignite_fw(self) -> None: # Defuse Zeus and launch remaining fireworks self.lp.defuse_fw(self.zeus_fw_id) rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR) @@ -888,7 +888,7 @@ def test_reignite_fw(self): fw_cache_state = wf.fw_states[fw_id] assert fw_state == fw_cache_state - def test_defuse_wf(self): + def test_defuse_wf(self) -> None: # defuse Workflow containing Zeus self.lp.defuse_wf(self.zeus_fw_id) defused_ids = self.lp.get_fw_ids({"state": "DEFUSED"}) @@ -902,7 +902,7 @@ def test_defuse_wf(self): fw_cache_state = wf.fw_states[fw_id] assert fw_state == fw_cache_state - def test_reignite_wf(self): + def test_reignite_wf(self) -> None: # Defuse workflow containing Zeus self.lp.defuse_wf(self.zeus_fw_id) @@ -919,7 +919,7 @@ def test_reignite_wf(self): fw_cache_state = wf.fw_states[fw_id] assert fw_state == fw_cache_state - def test_archive_wf(self): + def test_archive_wf(self) -> None: # Run a firework before archiving Zeus launch_rocket(self.lp, self.fworker) # archive Workflow containing Zeus. 
@@ -932,7 +932,7 @@ def test_archive_wf(self): fw_cache_state = wf.fw_states[fw_id] assert fw_state == fw_cache_state - def test_rerun_fws(self): + def test_rerun_fws(self) -> None: # Launch all fireworks rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR) fw = self.lp.get_fw_by_id(self.zeus_fw_id) @@ -949,15 +949,15 @@ def test_rerun_fws(self): fw_cache_state = wf.fw_states[fw_id] assert fw_state == fw_cache_state - def test_rerun_timed_fws(self): + def test_rerun_timed_fws(self) -> None: # Launch all fireworks in a separate process class RapidfireProcess(Process): - def __init__(self, lpad, fworker): + def __init__(self, lpad, fworker) -> None: super(self.__class__, self).__init__() self.lpad = lpad self.fworker = fworker - def run(self): + def run(self) -> None: rapidfire(self.lpad, self.fworker) rp = RapidfireProcess(self.lp, self.fworker) @@ -1013,7 +1013,7 @@ def run(self): class LaunchPadRerunExceptionTest(unittest.TestCase): @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.lp = None cls.fworker = FWorker() try: @@ -1023,11 +1023,11 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TEST_DB_NAME) - def setUp(self): + def setUp(self) -> None: fireworks.core.firework.EXCEPT_DETAILS_ON_RERUN = True self.error_test_dict = {"error": "description", "error_code": 1} @@ -1044,7 +1044,7 @@ def setUp(self): self.old_wd = os.getcwd() - def tearDown(self): + def tearDown(self) -> None: self.lp.reset(password=None, require_password=False) # Delete launch locations if os.path.exists(os.path.join("FW.json")): @@ -1053,14 +1053,14 @@ def tearDown(self): for ldir in glob.glob(os.path.join(MODULE_DIR, "launcher_*")): shutil.rmtree(ldir) - def test_except_details_on_rerun(self): + def test_except_details_on_rerun(self) -> None: rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR) assert os.getcwd() == MODULE_DIR self.lp.rerun_fw(1) fw = self.lp.get_fw_by_id(1) assert fw.spec["_exception_details"] == self.error_test_dict - def test_task_level_rerun(self): + def test_task_level_rerun(self) -> None: rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR) assert os.getcwd() == MODULE_DIR self.lp.rerun_fw(1, recover_launch="last") @@ -1077,7 +1077,7 @@ def test_task_level_rerun(self): fw = self.lp.get_fw_by_id(1) assert "_recovery" not in fw.spec - def test_task_level_rerun_cp(self): + def test_task_level_rerun_cp(self) -> None: rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR) assert os.getcwd() == MODULE_DIR self.lp.rerun_fw(1, recover_launch="last", recover_mode="cp") @@ -1090,7 +1090,7 @@ def test_task_level_rerun_cp(self): assert ExceptionTestTask.exec_counter == 2 assert filecmp.cmp(os.path.join(dirs[0], "date_file"), os.path.join(dirs[1], "date_file")) - def test_task_level_rerun_prev_dir(self): + def test_task_level_rerun_prev_dir(self) -> None: rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR) assert os.getcwd() == MODULE_DIR self.lp.rerun_fw(1, recover_launch="last", recover_mode="prev_dir") @@ -1106,7 +1106,7 @@ def test_task_level_rerun_prev_dir(self): class WFLockTest(unittest.TestCase): @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.lp = None cls.fworker = FWorker() try: @@ -1116,11 +1116,11 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TEST_DB_NAME) - def setUp(self): + def setUp(self) -> None: # set the defaults in the init of wflock to break the lock quickly fireworks.core.launchpad.WFLock(3, False).__init__.__func__.__defaults__ = (3, False) @@ -1133,7 +1133,7 @@ def setUp(self): self.old_wd = os.getcwd() - def tearDown(self): + def tearDown(self) -> None: self.lp.reset(password=None, require_password=False) # Delete launch locations if os.path.exists(os.path.join("FW.json")): @@ -1142,15 +1142,15 @@ def tearDown(self): for ldir in glob.glob(os.path.join(MODULE_DIR, "launcher_*")): shutil.rmtree(ldir) - def test_fix_db_inconsistencies_completed(self): + def test_fix_db_inconsistencies_completed(self) -> None: class RocketProcess(Process): - def __init__(self, lpad, fworker, fw_id): + def __init__(self, lpad, fworker, fw_id) -> None: super(self.__class__, self).__init__() self.lpad = lpad self.fworker = fworker self.fw_id = fw_id - def run(self): + def run(self) -> None: launch_rocket(self.lpad, self.fworker, fw_id=self.fw_id) # Launch the slow firework in a separate process @@ -1189,15 +1189,15 @@ def run(self): assert fast_fw.state == "COMPLETED" - def test_fix_db_inconsistencies_fizzled(self): + def test_fix_db_inconsistencies_fizzled(self) -> None: class RocketProcess(Process): - def __init__(self, lpad, fworker, fw_id): + def __init__(self, lpad, fworker, fw_id) -> None: super(self.__class__, self).__init__() self.lpad = lpad self.fworker = fworker self.fw_id = fw_id - def run(self): + def run(self) -> None: launch_rocket(self.lpad, self.fworker, fw_id=self.fw_id) self.lp.update_spec([2], {"fizzle": True}) @@ -1237,7 +1237,7 @@ def run(self): class LaunchPadOfflineTest(unittest.TestCase): @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.lp = None cls.fworker = FWorker() try: @@ -1247,11 +1247,11 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TEST_DB_NAME) - def setUp(self): + def setUp(self) -> None: fireworks.core.firework.EXCEPT_DETAILS_ON_RERUN = True self.error_test_dict = {"error": "description", "error_code": 1} @@ -1263,7 +1263,7 @@ def setUp(self): self.old_wd = os.getcwd() - def tearDown(self): + def tearDown(self) -> None: self.lp.reset(password=None, require_password=False) # Delete launch locations if os.path.exists(os.path.join("FW.json")): @@ -1272,7 +1272,7 @@ def tearDown(self): for ldir in glob.glob(os.path.join(MODULE_DIR, "launcher_*")): shutil.rmtree(ldir, ignore_errors=True) - def test__recover_completed(self): + def test__recover_completed(self) -> None: fw, launch_id = self.lp.reserve_fw(self.fworker, self.launch_dir) fw = self.lp.get_fw_by_id(1) with cd(self.launch_dir): @@ -1287,7 +1287,7 @@ def test__recover_completed(self): assert fw.state == "COMPLETED" - def test_recover_errors(self): + def test_recover_errors(self) -> None: fw, launch_id = self.lp.reserve_fw(self.fworker, self.launch_dir) fw = self.lp.get_fw_by_id(1) with cd(self.launch_dir): @@ -1318,7 +1318,7 @@ class GridfsStoredDataTest(unittest.TestCase): """ @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.lp = None cls.fworker = FWorker() try: @@ -1328,14 +1328,14 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TEST_DB_NAME) - def setUp(self): + def setUp(self) -> None: self.old_wd = os.getcwd() - def tearDown(self): + def tearDown(self) -> None: self.lp.reset(password=None, require_password=False) # Delete launch locations if os.path.exists(os.path.join("FW.json")): @@ -1344,7 +1344,7 @@ def tearDown(self): for ldir in glob.glob(os.path.join(MODULE_DIR, "launcher_*")): shutil.rmtree(ldir) - def test_many_detours(self): + def test_many_detours(self) -> None: task = DetoursTask(n_detours=2000, data_per_detour=["a" * 100] * 100) fw = Firework([task]) self.lp.add_wf(fw) @@ -1367,7 +1367,7 @@ def test_many_detours(self): wf = self.lp.get_wf_by_fw_id_lzyfw(1) assert len(wf.id_fw[1].launches[0].action.detours) == 2000 - def test_many_detours_offline(self): + def test_many_detours_offline(self) -> None: task = DetoursTask(n_detours=2000, data_per_detour=["a" * 100] * 100) fw = Firework([task]) self.lp.add_wf(fw) diff --git a/fireworks/core/tests/test_rocket.py b/fireworks/core/tests/test_rocket.py index 75dc02a89..5dac1fa09 100644 --- a/fireworks/core/tests/test_rocket.py +++ b/fireworks/core/tests/test_rocket.py @@ -11,7 +11,7 @@ class RocketTest(unittest.TestCase): @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.lp = None cls.fworker = FWorker() try: @@ -21,20 +21,20 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TESTDB_NAME) - def setUp(self): + def setUp(self) -> None: pass - def tearDown(self): + def tearDown(self) -> None: self.lp.reset(password=None, require_password=False) # Delete launch locations if os.path.exists(os.path.join("FW.json")): os.remove("FW.json") - def test_serializable_exception(self): + def test_serializable_exception(self) -> None: error_test_dict = {"error": "description", "error_code": 1} fw = Firework(ExceptionTestTask(exc_details=error_test_dict)) self.lp.add_wf(fw) @@ -45,7 +45,7 @@ def test_serializable_exception(self): launches = fw.launches assert launches[0].action.stored_data["_exception"]["_details"] == error_test_dict - def test_postproc_exception(self): + def test_postproc_exception(self) -> None: fw = Firework(MalformedAdditionTask()) self.lp.add_wf(fw) launch_rocket(self.lp, self.fworker) diff --git a/fireworks/core/tests/test_tracker.py b/fireworks/core/tests/test_tracker.py index 0f203d618..3914310ac 100644 --- a/fireworks/core/tests/test_tracker.py +++ b/fireworks/core/tests/test_tracker.py @@ -23,7 +23,7 @@ class TrackerTest(unittest.TestCase): @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.lp = None cls.fworker = FWorker() try: @@ -33,11 +33,11 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TESTDB_NAME) - def setUp(self): + def setUp(self) -> None: self.old_wd = os.getcwd() self.dest1 = os.path.join(MODULE_DIR, "numbers1.txt") self.dest2 = os.path.join(MODULE_DIR, "numbers2.txt") @@ -45,7 +45,7 @@ def setUp(self): self.tracker1 = Tracker(self.dest1, nlines=2) self.tracker2 = Tracker(self.dest2, nlines=2) - def tearDown(self): + def tearDown(self) -> None: self.lp.reset(password=None, require_password=False) if os.path.exists(os.path.join("FW.json")): os.remove("FW.json") @@ -54,12 +54,12 @@ def tearDown(self): shutil.rmtree(i) @staticmethod - def _teardown(dests): + def _teardown(dests) -> None: for f in dests: if os.path.exists(f): os.remove(f) - def test_tracker(self): + def test_tracker(self) -> None: """Launch a workflow and track the files.""" self._teardown([self.dest1]) try: @@ -78,7 +78,7 @@ def test_tracker(self): finally: self._teardown([self.dest1]) - def test_tracker_failed_fw(self): + def test_tracker_failed_fw(self) -> None: """Add a bad firetask to workflow and test the tracking.""" self._teardown([self.dest1]) try: @@ -107,12 +107,12 @@ def test_tracker_failed_fw(self): finally: self._teardown([self.dest1]) - def test_tracker_mlaunch(self): + def test_tracker_mlaunch(self) -> None: """Test the tracker for mlaunch.""" self._teardown([self.dest1, self.dest2]) try: - def add_wf(j, dest, tracker, name): + def add_wf(j, dest, tracker, name) -> None: fts = [] for i in range(j, j + 25): ft = ScriptTask.from_str('echo "' + str(i) + '" >> ' + dest, {"store_stdout": True}) diff --git a/fireworks/examples/custom_firetasks/hello_world/hello_world_task.py b/fireworks/examples/custom_firetasks/hello_world/hello_world_task.py index 04c30462f..b6dc31592 100644 --- a/fireworks/examples/custom_firetasks/hello_world/hello_world_task.py +++ b/fireworks/examples/custom_firetasks/hello_world/hello_world_task.py @@ -3,5 +3,5 @@ @explicit_serialize class HelloTask(FiretaskBase): - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: print("Hello, world!") diff --git a/fireworks/examples/custom_firetasks/merge_task/merge_task.py b/fireworks/examples/custom_firetasks/merge_task/merge_task.py index 273e3f92d..c356e5d41 100644 --- a/fireworks/examples/custom_firetasks/merge_task/merge_task.py +++ b/fireworks/examples/custom_firetasks/merge_task/merge_task.py @@ -24,7 +24,7 @@ def run_task(self, fw_spec): @explicit_serialize class TaskC(FiretaskBase): - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: print("This is task C.") print(f"Task A gave me: {fw_spec['param_A']}") print(f"Task B gave me: {fw_spec['param_B']}") diff --git a/fireworks/features/background_task.py b/fireworks/features/background_task.py index af5f9e6ad..88fc6c722 100644 --- a/fireworks/features/background_task.py +++ b/fireworks/features/background_task.py @@ -10,7 +10,7 @@ class BackgroundTask(FWSerializable): _fw_name = "BackgroundTask" - def __init__(self, tasks, num_launches=0, sleep_time=60, run_on_finish=False): + def __init__(self, tasks, num_launches=0, sleep_time=60, run_on_finish=False) -> None: """ Args: tasks [Firetask]: a list of Firetasks to perform diff --git a/fireworks/features/dupefinder.py b/fireworks/features/dupefinder.py index 62027f657..a543df4a2 100644 --- a/fireworks/features/dupefinder.py +++ b/fireworks/features/dupefinder.py @@ -1,5 +1,7 @@ """This module contains the base class for implementing Duplicate Finders.""" +from 
typing import NoReturn + from fireworks.utilities.fw_serializers import FWSerializable, serialize_fw __author__ = "Anubhav Jain" @@ -12,10 +14,10 @@ class DupeFinderBase(FWSerializable): """This serves an Abstract class for implementing Duplicate Finders.""" - def __init__(self): + def __init__(self) -> None: pass - def verify(self, spec1, spec2): + def verify(self, spec1, spec2) -> NoReturn: """ Method that checks whether two specs are identical enough to be considered duplicates. Return true if duplicated. Note that @@ -31,7 +33,7 @@ def verify(self, spec1, spec2): """ raise NotImplementedError - def query(self, spec): + def query(self, spec) -> NoReturn: """ Given a spec, returns a database query that gives potential candidates for duplicated Fireworks. diff --git a/fireworks/features/fw_report.py b/fireworks/features/fw_report.py index 8b6f09e93..d40ea1dde 100644 --- a/fireworks/features/fw_report.py +++ b/fireworks/features/fw_report.py @@ -25,7 +25,7 @@ class FWReport: - def __init__(self, lpad): + def __init__(self, lpad) -> None: """ Args: lpad (LaunchPad). diff --git a/fireworks/features/introspect.py b/fireworks/features/introspect.py index a8863b6ea..ecc82c2e8 100644 --- a/fireworks/features/introspect.py +++ b/fireworks/features/introspect.py @@ -80,7 +80,7 @@ def compare_stats(stats_dict1, n_samples1, stats_dict2, n_samples2, threshold=5) class Introspector: - def __init__(self, lpad): + def __init__(self, lpad) -> None: """ Args: lpad (LaunchPad). @@ -171,7 +171,7 @@ def introspect_fizzled(self, coll="fws", rsort=True, threshold=10, limit=100): return table @staticmethod - def print_report(table, coll): + def print_report(table, coll) -> None: if coll.lower() in ["fws", "fireworks"]: header_txt = "fireworks.spec" elif coll.lower() == "tasks": diff --git a/fireworks/features/multi_launcher.py b/fireworks/features/multi_launcher.py index 03b93ca34..546b846a4 100644 --- a/fireworks/features/multi_launcher.py +++ b/fireworks/features/multi_launcher.py @@ -16,7 +16,7 @@ __date__ = "Aug 19, 2013" -def ping_multilaunch(port, stop_event): +def ping_multilaunch(port, stop_event) -> None: """ A single manager to ping all launches during multiprocess launches. @@ -43,7 +43,7 @@ def ping_multilaunch(port, stop_event): def rapidfire_process( fworker, nlaunches, sleep, loglvl, port, node_list, sub_nproc, timeout, running_ids_dict, local_redirect -): +) -> None: """ Initializes shared data with multiprocessing parameters and starts a rapidfire. @@ -205,7 +205,7 @@ def launch_multiprocess( timeout=None, exclude_current_node=False, local_redirect=False, -): +) -> None: """ Launch the jobs in the job packing mode. diff --git a/fireworks/features/stats.py b/fireworks/features/stats.py index 39b395880..c381e64d6 100644 --- a/fireworks/features/stats.py +++ b/fireworks/features/stats.py @@ -22,7 +22,7 @@ class FWStats: - def __init__(self, lpad): + def __init__(self, lpad) -> None: """ Object to get Fireworks running stats from a LaunchPad. 
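
The dupefinder.py hunk above annotates the abstract verify and query stubs, which do nothing but raise NotImplementedError, with NoReturn. A minimal standalone sketch of what that annotation promises to a type checker (the function name is hypothetical):

    from typing import NoReturn

    def must_be_overridden() -> NoReturn:
        # NoReturn documents that control never leaves this function
        # normally: it always raises (or would loop forever).
        raise NotImplementedError("subclasses must implement this")
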
diff --git a/fireworks/features/tests/test_introspect.py b/fireworks/features/tests/test_introspect.py index 2137ad829..74f78ded4 100644 --- a/fireworks/features/tests/test_introspect.py +++ b/fireworks/features/tests/test_introspect.py @@ -6,7 +6,7 @@ class IntrospectTest(unittest.TestCase): - def test_flatten_dict(self): + def test_flatten_dict(self) -> None: assert set(flatten_to_keys({"d": {"e": {"f": 4}, "f": 10}}, max_recurs=1)) == { f"d{separator_str}" } diff --git a/fireworks/flask_site/gunicorn.py b/fireworks/flask_site/gunicorn.py index a65c25a3a..7618f23ff 100755 --- a/fireworks/flask_site/gunicorn.py +++ b/fireworks/flask_site/gunicorn.py @@ -12,12 +12,12 @@ def number_of_workers(): class StandaloneApplication(gunicorn.app.base.BaseApplication): - def __init__(self, app, options=None): + def __init__(self, app, options=None) -> None: self.options = options or {} self.application = app super().__init__() - def load_config(self): + def load_config(self) -> None: config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None} for key, value in config.items(): self.cfg.set(key.lower(), value) diff --git a/fireworks/flask_site/helpers.py b/fireworks/flask_site/helpers.py index 454acd00b..19196fd9d 100644 --- a/fireworks/flask_site/helpers.py +++ b/fireworks/flask_site/helpers.py @@ -16,8 +16,7 @@ def fw_filt_given_wf_filt(filt, lp): def wf_filt_given_fw_filt(filt, lp): wf_ids = set() - for doc in lp.fireworks.find(filt, {"_id": 0, "fw_id": 1}): - wf_ids.add(doc["fw_id"]) + wf_ids.update(doc["fw_id"] for doc in lp.fireworks.find(filt, {"_id": 0, "fw_id": 1})) return {"nodes": {"$in": list(wf_ids)}} diff --git a/fireworks/fw_config.py b/fireworks/fw_config.py index fb3e5b423..24235e19c 100644 --- a/fireworks/fw_config.py +++ b/fireworks/fw_config.py @@ -177,7 +177,7 @@ def write_config(path: str | None = None) -> None: class FWData: """This class stores data that a Firetask might want to access, e.g. to see the runtime params.""" - def __init__(self): + def __init__(self) -> None: self.MULTIPROCESSING = None # default single process framework self.NODE_LIST = None # the node list for sub jobs self.SUB_NPROCS = None # the number of process of the sub job diff --git a/fireworks/queue/queue_adapter.py b/fireworks/queue/queue_adapter.py index 87296c129..37b806d18 100644 --- a/fireworks/queue/queue_adapter.py +++ b/fireworks/queue/queue_adapter.py @@ -34,7 +34,7 @@ class Command: status = None output, error = "", "" - def __init__(self, command): + def __init__(self, command) -> None: """ initialize the object. @@ -57,7 +57,7 @@ def run(self, timeout=None, **kwargs): (status, output, error) """ - def target(**kwargs): + def target(**kwargs) -> None: try: self.process = subprocess.Popen(self.command, **kwargs) self.output, self.error = self.process.communicate() diff --git a/fireworks/queue/queue_launcher.py b/fireworks/queue/queue_launcher.py index 4b2b5b497..b99827c93 100644 --- a/fireworks/queue/queue_launcher.py +++ b/fireworks/queue/queue_launcher.py @@ -179,7 +179,7 @@ def rapidfire( strm_lvl="INFO", timeout=None, fill_mode=False, -): +) -> None: """ Submit many jobs to the queue. 
@@ -330,7 +330,7 @@ def _get_number_of_jobs_in_queue(qadapter, njobs_queue, l_logger): raise RuntimeError("Unable to determine number of jobs in queue, check queue adapter and queue server status!") -def setup_offline_job(launchpad, fw, launch_id): +def setup_offline_job(launchpad, fw, launch_id) -> None: # separate this function out for reuse in unit testing fw.to_file("FW.json") with open("FW_offline.json", "w") as f: diff --git a/fireworks/scripts/lpad_run.py b/fireworks/scripts/lpad_run.py index 1c0165f6b..647d73dd3 100644 --- a/fireworks/scripts/lpad_run.py +++ b/fireworks/scripts/lpad_run.py @@ -683,7 +683,7 @@ def set_priority(args: Namespace) -> None: lp.m_logger.info(f"Finished setting priorities of {len(fw_ids)} FWs") -def _open_webbrowser(url): +def _open_webbrowser(url) -> None: """Open a web browser after a delay to give the web server more startup time.""" import webbrowser diff --git a/fireworks/scripts/qlaunch_run.py b/fireworks/scripts/qlaunch_run.py index a634429f3..1f3a8f9fe 100644 --- a/fireworks/scripts/qlaunch_run.py +++ b/fireworks/scripts/qlaunch_run.py @@ -34,7 +34,7 @@ __date__ = "Jan 14, 2013" -def do_launch(args): +def do_launch(args) -> None: cfg_files_to_check = [ ("launchpad", "-l", False, LAUNCHPAD_LOC), ("fworker", "-w", False, FWORKER_LOC), diff --git a/fireworks/scripts/rlaunch_run.py b/fireworks/scripts/rlaunch_run.py index 51bb188e2..ce3150c40 100644 --- a/fireworks/scripts/rlaunch_run.py +++ b/fireworks/scripts/rlaunch_run.py @@ -26,7 +26,7 @@ __date__ = "Feb 7, 2013" -def handle_interrupt(signum, frame): +def handle_interrupt(signum, frame) -> None: sys.stderr.write(f"Interrupted by signal {signum:d}\n") sys.exit(1) diff --git a/fireworks/scripts/tests/test_lpad_run.py b/fireworks/scripts/tests/test_lpad_run.py index 4f2451d2f..d5244c676 100644 --- a/fireworks/scripts/tests/test_lpad_run.py +++ b/fireworks/scripts/tests/test_lpad_run.py @@ -20,7 +20,7 @@ def lp(capsys): @pytest.mark.parametrize(("detail", "expected_1", "expected_2"), [("count", "0\n", "1\n"), ("ids", "[]\n", "1\n")]) -def test_lpad_get_fws(capsys, lp, detail, expected_1, expected_2): +def test_lpad_get_fws(capsys, lp, detail, expected_1, expected_2) -> None: """Test lpad CLI get_fws command.""" ret_code = lpad(["get_fws", "-d", detail]) assert ret_code == 0 @@ -45,7 +45,7 @@ def test_lpad_get_fws(capsys, lp, detail, expected_1, expected_2): @pytest.mark.parametrize("arg", ["-v", "--version"]) -def test_lpad_report_version(capsys, arg): +def test_lpad_report_version(capsys, arg) -> None: """Test lpad CLI version flag.""" with pytest.raises(SystemExit, match="0"): lpad([arg]) @@ -56,7 +56,7 @@ def test_lpad_report_version(capsys, arg): assert stderr == "" -def test_lpad_config_file_flags(): +def test_lpad_config_file_flags() -> None: """Test lpad CLI throws errors on missing config file flags.""" with pytest.raises(FileNotFoundError, match="launchpad_file '' does not exist!"): lpad(["-l", "", "get_fws"]) diff --git a/fireworks/scripts/tests/test_mlaunch_run.py b/fireworks/scripts/tests/test_mlaunch_run.py index 1c7827bd4..aa6f2f9a4 100644 --- a/fireworks/scripts/tests/test_mlaunch_run.py +++ b/fireworks/scripts/tests/test_mlaunch_run.py @@ -6,7 +6,7 @@ @pytest.mark.parametrize("arg", ["-v", "--version"]) -def test_mlaunch_report_version(capsys, arg): +def test_mlaunch_report_version(capsys, arg) -> None: """Test mlaunch CLI version flag.""" with pytest.raises(SystemExit, match="0"): mlaunch([arg]) @@ -17,7 +17,7 @@ def test_mlaunch_report_version(capsys, arg): assert stderr == "" -def 
test_mlaunch_config_file_flags(): +def test_mlaunch_config_file_flags() -> None: """Test mlaunch CLI throws errors on missing config file flags.""" num_jobs = "1" diff --git a/fireworks/scripts/tests/test_qlaunch_run.py b/fireworks/scripts/tests/test_qlaunch_run.py index 6c04a7318..3200f0980 100644 --- a/fireworks/scripts/tests/test_qlaunch_run.py +++ b/fireworks/scripts/tests/test_qlaunch_run.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize("arg", ["-v", "--version"]) -def test_qlaunch_report_version(capsys, arg): +def test_qlaunch_report_version(capsys, arg) -> None: """Test qlaunch CLI version flag.""" with pytest.raises(SystemExit): qlaunch([arg]) @@ -19,7 +19,7 @@ def test_qlaunch_report_version(capsys, arg): assert stderr == "" -def test_qlaunch_config_file_flags(): +def test_qlaunch_config_file_flags() -> None: """Test qlaunch CLI throws errors on missing config file flags.""" # qadapter.yaml is mandatory, test for ValueError if missing with pytest.raises(ValueError, match="No path specified for qadapter_file."): diff --git a/fireworks/scripts/tests/test_rlaunch_run.py b/fireworks/scripts/tests/test_rlaunch_run.py index 2895e4e1d..bc1f60327 100644 --- a/fireworks/scripts/tests/test_rlaunch_run.py +++ b/fireworks/scripts/tests/test_rlaunch_run.py @@ -6,7 +6,7 @@ @pytest.mark.parametrize("arg", ["-v", "--version"]) -def test_rlaunch_report_version(capsys, arg): +def test_rlaunch_report_version(capsys, arg) -> None: """Test rlaunch CLI version flag.""" with pytest.raises(SystemExit, match="0"): rlaunch([arg]) @@ -17,7 +17,7 @@ def test_rlaunch_report_version(capsys, arg): assert stderr == "" -def test_rlaunch_config_file_flags(): +def test_rlaunch_config_file_flags() -> None: """Test rlaunch CLI throws errors on missing config file flags.""" with pytest.raises(FileNotFoundError, match="launchpad_file '' does not exist!"): rlaunch(["-l", ""]) diff --git a/fireworks/tests/master_tests.py b/fireworks/tests/master_tests.py index 73e369b5d..e1dbbc3ba 100644 --- a/fireworks/tests/master_tests.py +++ b/fireworks/tests/master_tests.py @@ -23,7 +23,7 @@ class TestImports(unittest.TestCase): """Make sure that required external libraries can be imported.""" - def test_imports(self): + def test_imports(self) -> None: pass # test that MongoClient is available (newer pymongo) @@ -31,7 +31,7 @@ def test_imports(self): class BasicTests(unittest.TestCase): """Make sure that required external libraries can be imported.""" - def test_fwconnector(self): + def test_fwconnector(self) -> None: fw1 = Firework(ScriptTask.from_str('echo "1"')) fw2 = Firework(ScriptTask.from_str('echo "1"')) @@ -44,7 +44,7 @@ def test_fwconnector(self): wf3 = Workflow([fw1, fw2]) assert wf3.links == {fw1.fw_id: [], fw2.fw_id: []} - def test_parentconnector(self): + def test_parentconnector(self) -> None: fw1 = Firework(ScriptTask.from_str('echo "1"')) fw2 = Firework(ScriptTask.from_str('echo "1"'), parents=fw1) fw3 = Firework(ScriptTask.from_str('echo "1"'), parents=[fw1, fw2]) @@ -69,7 +69,7 @@ def get_data(obj_dict): return cls_.from_dict(obj_dict) return None - def test_serialization_details(self): + def test_serialization_details(self) -> None: # This detects a weird bug found in early version of serializers pbs = CommonAdapter("PBS") @@ -78,7 +78,7 @@ def test_serialization_details(self): assert isinstance(load_object(pbs.to_dict()), CommonAdapter) assert isinstance(self.get_data(pbs.to_dict()), CommonAdapter) # repeated test on purpose! 
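The BasicTests hunks above (test_fwconnector / test_parentconnector) exercise the two ways of wiring the DAG: explicit links passed to Workflow, or parents= on each Firework. A short sketch of the parents-based form, assuming the usual top-level fireworks exports; fw_ids are assigned automatically:

    from fireworks import Firework, ScriptTask, Workflow

    fw1 = Firework(ScriptTask.from_str('echo "1"'))
    fw2 = Firework(ScriptTask.from_str('echo "2"'), parents=fw1)
    fw3 = Firework(ScriptTask.from_str('echo "3"'), parents=[fw1, fw2])

    # Workflow() derives its links dict from the parents relationships
    wf = Workflow([fw1, fw2, fw3])
    print(wf.links)  # roughly: {fw1.fw_id: [fw2.fw_id, fw3.fw_id], fw2.fw_id: [fw3.fw_id], fw3.fw_id: []}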
- def test_recursive_deserialize(self): + def test_recursive_deserialize(self) -> None: my_dict = { "update_spec": {}, "mod_spec": [], diff --git a/fireworks/tests/mongo_tests.py b/fireworks/tests/mongo_tests.py index bef1aa983..d3c0fd950 100644 --- a/fireworks/tests/mongo_tests.py +++ b/fireworks/tests/mongo_tests.py @@ -6,6 +6,7 @@ import time import unittest from multiprocessing import Pool +from typing import NoReturn import pytest @@ -36,14 +37,14 @@ NCORES_PARALLEL_TEST = 4 -def random_launch(lp_creds): +def random_launch(lp_creds) -> None: lp = LaunchPad.from_dict(lp_creds) while lp.run_exists(None): launch_rocket(lp) time.sleep(random.random() / 3 + 0.1) -def throw_error(msg): +def throw_error(msg) -> NoReturn: raise ValueError(msg) @@ -74,7 +75,7 @@ def run_task(self, fw_spec): class MongoTests(unittest.TestCase): @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.lp = None cls.fworker = FWorker() try: @@ -84,21 +85,21 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TESTDB_NAME) @staticmethod - def _teardown(dests): + def _teardown(dests) -> None: for f in dests: if os.path.exists(f): os.remove(f) - def setUp(self): + def setUp(self) -> None: self.lp.reset(password=None, require_password=False) self.old_wd = os.getcwd() - def tearDown(self): + def tearDown(self) -> None: self.lp.reset(password=None, require_password=False) if os.path.exists(os.path.join("FW.json")): os.remove("FW.json") @@ -110,14 +111,14 @@ def tearDown(self): for i in glob.glob(os.path.join(MODULE_DIR, "launcher*")): shutil.rmtree(i) - def test_basic_fw(self): + def test_basic_fw(self) -> None: test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True}) fw = Firework(test1) self.lp.add_wf(fw) launch_rocket(self.lp, self.fworker) assert self.lp.get_launch_by_id(1).action.stored_data["stdout"] == "test1\n" - def test_basic_fw_offline(self): + def test_basic_fw_offline(self) -> None: test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True}) fw = Firework(test1) self.lp.add_wf(fw) @@ -152,7 +153,7 @@ def test_basic_fw_offline(self): self.lp.recover_offline(launch["launch_id"]) assert self.lp.get_launch_by_id(1).action.stored_data["stdout"] == "test1\n" - def test_offline_fw_passinfo(self): + def test_offline_fw_passinfo(self) -> None: fw1 = Firework([AdditionTask()], {"input_array": [1, 1]}, name="1") fw2 = Firework([AdditionTask()], {"input_array": [2, 2]}, name="2") fw3 = Firework([AdditionTask()], {"input_array": [3]}, parents=[fw1, fw2], name="3") @@ -198,7 +199,7 @@ def test_offline_fw_passinfo(self): assert set(child_fw.spec["input_array"]) == {2, 3, 4} assert child_fw.launches[0].action.stored_data["sum"] == 9 - def test_multi_fw(self): + def test_multi_fw(self) -> None: test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True}) test2 = ScriptTask.from_str("python -c 'print(\"test2\")'", {"store_stdout": True}) fw = Firework([test1, test2]) @@ -206,7 +207,7 @@ def test_multi_fw(self): launch_rocket(self.lp, self.fworker) assert self.lp.get_launch_by_id(1).action.stored_data["stdout"] == "test2\n" - def test_multi_fw_complex(self): + def test_multi_fw_complex(self) -> None: dest1 = os.path.join(os.path.dirname(os.path.abspath(__file__)), "inputs.txt") dest2 = os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp_file.txt") 
self._teardown([dest1, dest2]) @@ -234,7 +235,7 @@ def test_multi_fw_complex(self): finally: self._teardown([dest1, dest2]) - def test_backgroundtask(self): + def test_backgroundtask(self) -> None: dest1 = os.path.join(os.path.dirname(os.path.abspath(__file__)), "hello.txt") self._teardown([dest1]) @@ -256,13 +257,13 @@ def test_backgroundtask(self): finally: self._teardown([dest1]) - def test_add_fw(self): + def test_add_fw(self) -> None: fw = Firework(AdditionTask(), {"input_array": [5, 7]}) self.lp.add_wf(fw) rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR) assert self.lp.get_launch_by_id(1).action.stored_data["sum"] == 12 - def test_org_wf(self): + def test_org_wf(self) -> None: test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True}) test2 = ScriptTask.from_str("python -c 'print(\"test2\")'", {"store_stdout": True}) fw1 = Firework(test1, fw_id=-1) @@ -274,7 +275,7 @@ def test_org_wf(self): launch_rocket(self.lp, self.fworker) assert self.lp.get_launch_by_id(2).action.stored_data["stdout"] == "test2\n" - def test_fibadder(self): + def test_fibadder(self) -> None: fib = FibonacciAdderTask() fw = Firework(fib, {"smaller": 0, "larger": 1, "stop_point": 3}) self.lp.add_wf(fw) @@ -285,7 +286,7 @@ def test_fibadder(self): assert self.lp.get_launch_by_id(3).action.stored_data == {} assert not self.lp.run_exists() - def test_parallel_fibadder(self): + def test_parallel_fibadder(self) -> None: # this is really testing to see if a Workflow can handle multiple FWs updating it at once parent = Firework(ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True})) fib1 = Firework(FibonacciAdderTask(), {"smaller": 0, "larger": 1, "stop_point": 30}, parents=[parent]) @@ -300,7 +301,7 @@ def test_parallel_fibadder(self): creds_array = [self.lp.to_dict()] * NCORES_PARALLEL_TEST p.map(random_launch, creds_array) - def test_multi_detour(self): + def test_multi_detour(self) -> None: fw1 = Firework([MultipleDetourTask()], fw_id=1) fw2 = Firework([ScriptTask.from_str('echo "DONE"')], parents=[fw1], fw_id=2) self.lp.add_wf(Workflow([fw1, fw2])) @@ -312,7 +313,7 @@ def test_multi_detour(self): assert set(links[4]) == {2} assert set(links[5]) == {2} - def test_fw_env(self): + def test_fw_env(self) -> None: t = DummyFWEnvTask() fw = Firework(t) self.lp.add_wf(fw) @@ -322,7 +323,7 @@ def test_fw_env(self): launch_rocket(self.lp, FWorker(env={"hello": "world"})) assert self.lp.get_launch_by_id(2).action.stored_data["data"] == "world" - def test_job_info(self): + def test_job_info(self) -> None: fw1 = Firework([ScriptTask.from_str('echo "Testing job info"')], spec={"_pass_job_info": True}, fw_id=1) fw2 = Firework([DummyJobPassTask()], parents=[fw1], spec={"_pass_job_info": True, "target": 1}, fw_id=2) fw3 = Firework([DummyJobPassTask()], parents=[fw2], spec={"target": 2}, fw_id=3) @@ -361,7 +362,7 @@ def test_job_info(self): assert len(modified_spec["_job_info"]) == 2 - def test_files_in_out(self): + def test_files_in_out(self) -> None: # create the Workflow that passes files_in and files_out fw1 = Firework( [ScriptTask.from_str('echo "This is the first FireWork" > test1')], @@ -390,7 +391,7 @@ def test_files_in_out(self): for f in ["test1", "hello.gz", "fwtest.2"]: os.remove(f) - def test_preserve_fworker(self): + def test_preserve_fworker(self) -> None: fw1 = Firework( [ScriptTask.from_str('echo "Testing preserve FWorker"')], spec={"_preserve_fworker": True}, fw_id=1 ) @@ -415,14 +416,14 @@ def test_preserve_fworker(self): assert modified_spec["_fworker"] is not None - 
def test_add_lp_and_fw_id(self): + def test_add_lp_and_fw_id(self) -> None: fw1 = Firework([DummyLPTask()], spec={"_add_launchpad_and_fw_id": True}) self.lp.add_wf(fw1) launch_rocket(self.lp, self.fworker) assert self.lp.get_launch_by_id(1).action.stored_data["fw_id"] == 1 assert self.lp.get_launch_by_id(1).action.stored_data["host"] is not None - def test_spec_copy(self): + def test_spec_copy(self) -> None: task1 = ScriptTask.from_str('echo "Task 1"') task2 = ScriptTask.from_str('echo "Task 2"') @@ -436,7 +437,7 @@ def test_spec_copy(self): assert self.lp.get_fw_by_id(1).tasks[0]["script"][0] == 'echo "Task 1"' assert self.lp.get_fw_by_id(2).tasks[0]["script"][0] == 'echo "Task 2"' - def test_category(self): + def test_category(self) -> None: task1 = ScriptTask.from_str('echo "Task 1"') task2 = ScriptTask.from_str('echo "Task 2"') @@ -453,7 +454,7 @@ def test_category(self): assert self.lp.run_exists(FWorker()) # can run any category assert self.lp.run_exists(FWorker(category=["dummy_category", "other category"])) - def test_category_pt2(self): + def test_category_pt2(self) -> None: task1 = ScriptTask.from_str('echo "Task 1"') task2 = ScriptTask.from_str('echo "Task 2"') @@ -467,7 +468,7 @@ def test_category_pt2(self): assert self.lp.run_exists(FWorker()) # can run any category assert not self.lp.run_exists(FWorker(category=["dummy_category", "other category"])) - def test_delete_fw(self): + def test_delete_fw(self) -> None: test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True}) fw = Firework(test1) self.lp.add_wf(fw) @@ -479,7 +480,7 @@ def test_delete_fw(self): with pytest.raises(ValueError): self.lp.get_launch_by_id(1) - def test_duplicate_delete_fw(self): + def test_duplicate_delete_fw(self) -> None: test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True}) fw = Firework(test1, {"_dupefinder": DupeFinderExact()}) self.lp.add_wf(fw) @@ -496,7 +497,7 @@ def test_duplicate_delete_fw(self): self.lp.get_fw_by_id(del_id) assert self.lp.get_launch_by_id(1).action.stored_data["stdout"] == "test1\n" - def test_dupefinder(self): + def test_dupefinder(self) -> None: test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True}) fw = Firework(test1, {"_dupefinder": DupeFinderExact()}) self.lp.add_wf(fw) @@ -516,7 +517,7 @@ def test_dupefinder(self): print("--------") assert self.lp.launches.count_documents({}) == 1 - def test_append_wf(self): + def test_append_wf(self) -> None: fw1 = Firework([UpdateSpecTask()]) fw2 = Firework([ModSpecTask()]) self.lp.add_wf(Workflow([fw1, fw2])) @@ -548,7 +549,7 @@ def test_append_wf(self): with pytest.raises(ValueError): self.lp.append_wf(new_wf, [4], detour=True) - def test_append_wf_detour(self): + def test_append_wf_detour(self) -> None: fw1 = Firework([ModSpecTask()], fw_id=1) fw2 = Firework([ModSpecTask()], fw_id=2, parents=[fw1]) self.lp.add_wf(Workflow([fw1, fw2])) @@ -561,7 +562,7 @@ def test_append_wf_detour(self): assert self.lp.get_fw_by_id(2).spec["dummy2"] == [True, True] - def test_force_lock_removal(self): + def test_force_lock_removal(self) -> None: test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True}) fw = Firework(test1, {"_dupefinder": DupeFinderExact()}, fw_id=1) self.lp.add_wf(fw) @@ -569,7 +570,7 @@ def test_force_lock_removal(self): with WFLock(self.lp, 1), WFLock(self.lp, 1, kill=True, expire_secs=1): assert True # dummy to make sure we got here - def test_fizzle(self): + def test_fizzle(self) -> None: p = 
PyTask(func="fireworks.tests.mongo_tests.throw_error", args=["Testing; this error is normal."]) fw = Firework(p) self.lp.add_wf(fw) @@ -577,21 +578,21 @@ def test_fizzle(self): assert self.lp.get_fw_by_id(1).state == "FIZZLED" assert not launch_rocket(self.lp, self.fworker) - def test_defuse(self): + def test_defuse(self) -> None: p = PyTask(func="fireworks.tests.mongo_tests.throw_error", args=["This should not happen"]) fw = Firework(p) self.lp.add_wf(fw) self.lp.defuse_fw(fw.fw_id) assert not launch_rocket(self.lp, self.fworker) - def test_archive(self): + def test_archive(self) -> None: p = PyTask(func="fireworks.tests.mongo_tests.throw_error", args=["This should not happen"]) fw = Firework(p) self.lp.add_wf(fw) self.lp.archive_wf(fw.fw_id) assert not launch_rocket(self.lp, self.fworker) - def test_stats(self): + def test_stats(self) -> None: test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True}) fw = Firework(test1) self.lp.add_wf(fw) diff --git a/fireworks/tests/multiprocessing_tests.py b/fireworks/tests/multiprocessing_tests.py index 48ca49705..e6a0a1e8d 100644 --- a/fireworks/tests/multiprocessing_tests.py +++ b/fireworks/tests/multiprocessing_tests.py @@ -17,7 +17,7 @@ class TestLinks(TestCase): - def test_pickle(self): + def test_pickle(self) -> None: links1 = Workflow.Links({1: 2, 3: [5, 7, 8]}) s = pickle.dumps(links1) links2 = pickle.loads(s) @@ -28,7 +28,7 @@ class TestCheckoutFW(TestCase): lp = None @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.fworker = FWorker() try: cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") @@ -37,14 +37,14 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost: 27017! Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TESTDB_NAME) - def setUp(self): + def setUp(self) -> None: self.old_wd = os.getcwd() - def tearDown(self): + def tearDown(self) -> None: self.lp.reset(password=None, require_password=False) os.chdir(self.old_wd) if os.path.exists(os.path.join("FW.json")): @@ -53,7 +53,7 @@ def tearDown(self): for i in glob.glob(os.path.join(MODULE_DIR, "launcher*")): shutil.rmtree(i) - def test_checkout_fw(self): + def test_checkout_fw(self) -> None: os.chdir(MODULE_DIR) self.lp.add_wf( Firework(ScriptTask.from_str(shell_cmd='echo "hello 1"', parameters={"stdout_file": "task.out"}), fw_id=1) @@ -76,7 +76,7 @@ class TestEarlyExit(TestCase): lp = None @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.fworker = FWorker() try: cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") @@ -85,14 +85,14 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TESTDB_NAME) - def setUp(self): + def setUp(self) -> None: self.old_wd = os.getcwd() - def tearDown(self): + def tearDown(self) -> None: self.lp.reset(password=None, require_password=False) os.chdir(self.old_wd) if os.path.exists(os.path.join("FW.json")): @@ -101,7 +101,7 @@ def tearDown(self): for i in glob.glob(os.path.join(MODULE_DIR, "launcher*")): shutil.rmtree(i) - def test_early_exit(self): + def test_early_exit(self) -> None: os.chdir(MODULE_DIR) script_text = "echo hello from process $PPID; sleep 2" fw1 = Firework(ScriptTask.from_str(shell_cmd=script_text, parameters={"stdout_file": "task.out"}), fw_id=1) diff --git a/fireworks/tests/test_fw_config.py b/fireworks/tests/test_fw_config.py index d08677565..ab8a0ef46 100644 --- a/fireworks/tests/test_fw_config.py +++ b/fireworks/tests/test_fw_config.py @@ -10,7 +10,7 @@ class ConfigTest(unittest.TestCase): - def test_config(self): + def test_config(self) -> None: d = config_to_dict() assert "NEGATIVE_FWID_CTR" not in d diff --git a/fireworks/tests/test_workflow.py b/fireworks/tests/test_workflow.py index 655361be6..875bcd194 100644 --- a/fireworks/tests/test_workflow.py +++ b/fireworks/tests/test_workflow.py @@ -4,55 +4,55 @@ class TestWorkflowState(unittest.TestCase): - def test_completed(self): + def test_completed(self) -> None: # all leaves complete one = fw.Firework([], state="COMPLETED", fw_id=1) two = fw.Firework([], state="COMPLETED", fw_id=2) assert fw.Workflow([one, two]).state == "COMPLETED" - def test_archived(self): + def test_archived(self) -> None: one = fw.Firework([], state="ARCHIVED", fw_id=1) two = fw.Firework([], state="ARCHIVED", fw_id=2) assert fw.Workflow([one, two]).state == "ARCHIVED" - def test_defused(self): + def test_defused(self) -> None: # any defused == defused one = fw.Firework([], state="COMPLETED", fw_id=1) two = fw.Firework([], state="DEFUSED", fw_id=2) assert fw.Workflow([one, two]).state == "DEFUSED" - def test_paused(self): + def test_paused(self) -> None: # any paused == paused one = fw.Firework([], state="COMPLETED", fw_id=1) two = fw.Firework([], state="PAUSED", fw_id=2) assert fw.Workflow([one, two]).state == "PAUSED" - def test_fizzled_1(self): + def test_fizzled_1(self) -> None: # WF(Fizzled -> Waiting(no fizz parents)) == FIZZLED one = fw.Firework([], state="FIZZLED", fw_id=1) two = fw.Firework([], state="WAITING", fw_id=2, parents=one) assert fw.Workflow([one, two]).state == "FIZZLED" - def test_fizzled_2(self): + def test_fizzled_2(self) -> None: # WF(Fizzled -> Ready(allow fizz parents)) == RUNNING one = fw.Firework([], state="FIZZLED", fw_id=1) two = fw.Firework([], state="READY", fw_id=2, spec={"_allow_fizzled_parents": True}, parents=one) assert fw.Workflow([one, two]).state == "RUNNING" - def test_fizzled_3(self): + def test_fizzled_3(self) -> None: # WF(Fizzled -> Completed(allow fizz parents)) == COMPLETED one = fw.Firework([], state="FIZZLED", fw_id=1) two = fw.Firework([], state="COMPLETED", fw_id=2, spec={"_allow_fizzled_parents": True}, parents=one) assert fw.Workflow([one, two]).state == "COMPLETED" - def test_fizzled_4(self): + def test_fizzled_4(self) -> None: # one child doesn't allow fizzled parents one = fw.Firework([], state="FIZZLED", fw_id=1) two = fw.Firework([], state="READY", fw_id=2, spec={"_allow_fizzled_parents": True}, parents=one) @@ -60,39 +60,39 @@ def test_fizzled_4(self): assert fw.Workflow([one, two, three]).state == 
"FIZZLED" - def test_fizzled_5(self): + def test_fizzled_5(self) -> None: # leaf is fizzled, wf is fizzled one = fw.Firework([], state="COMPLETED", fw_id=1) two = fw.Firework([], state="FIZZLED", fw_id=2, parents=one) assert fw.Workflow([one, two]).state == "FIZZLED" - def test_fizzled_6(self): + def test_fizzled_6(self) -> None: # deep fizzled fireworks, but still RUNNING one = fw.Firework([], state="FIZZLED", fw_id=1) two = fw.Firework([], state="FIZZLED", fw_id=2, spec={"_allow_fizzled_parents": True}, parents=one) three = fw.Firework([], state="READY", fw_id=3, spec={"_allow_fizzled_parents": True}, parents=two) assert fw.Workflow([one, two, three]).state == "RUNNING" - def test_running_1(self): + def test_running_1(self) -> None: one = fw.Firework([], state="COMPLETED", fw_id=1) two = fw.Firework([], state="READY", fw_id=2, parents=one) assert fw.Workflow([one, two]).state == "RUNNING" - def test_running_2(self): + def test_running_2(self) -> None: one = fw.Firework([], state="RUNNING", fw_id=1) two = fw.Firework([], state="WAITING", fw_id=2, parents=one) assert fw.Workflow([one, two]).state == "RUNNING" - def test_reserved(self): + def test_reserved(self) -> None: one = fw.Firework([], state="RESERVED", fw_id=1) two = fw.Firework([], state="READY", fw_id=2, parents=one) assert fw.Workflow([one, two]).state == "RESERVED" - def test_ready(self): + def test_ready(self) -> None: one = fw.Firework([], state="READY", fw_id=1) two = fw.Firework([], state="READY", fw_id=2, parents=one) diff --git a/fireworks/user_objects/firetasks/fileio_tasks.py b/fireworks/user_objects/firetasks/fileio_tasks.py index 29bcc8452..7ff5923e8 100755 --- a/fireworks/user_objects/firetasks/fileio_tasks.py +++ b/fireworks/user_objects/firetasks/fileio_tasks.py @@ -31,7 +31,7 @@ class FileWriteTask(FiretaskBase): required_params = ["files_to_write"] optional_params = ["dest"] - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: pth = self.get("dest", os.getcwd()) for d in self["files_to_write"]: with open(os.path.join(pth, d["filename"]), "w") as f: @@ -54,7 +54,7 @@ class FileDeleteTask(FiretaskBase): required_params = ["files_to_delete"] optional_params = ["dest", "ignore_errors"] - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: pth = self.get("dest", os.getcwd()) ignore_errors = self.get("ignore_errors", True) for f in self["files_to_delete"]: @@ -98,7 +98,7 @@ class FileTransferTask(FiretaskBase): "copyfile": shutil.copyfile, } - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: shell_interpret = self.get("shell_interpret", True) ignore_errors = self.get("ignore_errors", False) max_retry = self.get("max_retry", 0) @@ -162,7 +162,7 @@ def run_task(self, fw_spec): ssh.close() @staticmethod - def _rexists(sftp, path): + def _rexists(sftp, path) -> bool: """os.path.exists for paramiko's SCP object.""" try: sftp.stat(path) @@ -187,7 +187,7 @@ class CompressDirTask(FiretaskBase): _fw_name = "CompressDirTask" optional_params = ["compression", "dest", "ignore_errors"] - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: ignore_errors = self.get("ignore_errors", False) dest = self.get("dest", os.getcwd()) compression = self.get("compression", "gz") @@ -211,7 +211,7 @@ class DecompressDirTask(FiretaskBase): _fw_name = "DecompressDirTask" optional_params = ["dest", "ignore_errors"] - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: ignore_errors = self.get("ignore_errors", False) dest = self.get("dest", os.getcwd()) try: @@ -235,5 
+235,5 @@ class ArchiveDirTask(FiretaskBase): required_params = ["base_name"] optional_params = ["format"] - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: shutil.make_archive(self["base_name"], format=self.get("format", "gztar"), root_dir=".") diff --git a/fireworks/user_objects/firetasks/filepad_tasks.py b/fireworks/user_objects/firetasks/filepad_tasks.py index b4bec5d5b..27f918d09 100644 --- a/fireworks/user_objects/firetasks/filepad_tasks.py +++ b/fireworks/user_objects/firetasks/filepad_tasks.py @@ -27,7 +27,7 @@ class AddFilesTask(FiretaskBase): required_params = ["paths"] optional_params = ["identifiers", "directory", "filepad_file", "compress", "metadata"] - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: from glob import glob directory = os.path.abspath(self.get("directory", ".")) @@ -68,7 +68,7 @@ class GetFilesTask(FiretaskBase): required_params = ["identifiers"] optional_params = ["filepad_file", "dest_dir", "new_file_names"] - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: fpad = get_fpad(self.get("filepad_file", None)) dest_dir = self.get("dest_dir", os.path.abspath(".")) new_file_names = self.get("new_file_names", []) @@ -142,7 +142,7 @@ class GetFilesByQueryTask(FiretaskBase): "sort_key", ] - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: import json import pymongo @@ -206,7 +206,7 @@ class DeleteFilesTask(FiretaskBase): required_params = ["identifiers"] optional_params = ["filepad_file"] - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: fpad = get_fpad(self.get("filepad_file", None)) for file in self["identifiers"]: fpad.delete_file(file) diff --git a/fireworks/user_objects/firetasks/script_task.py b/fireworks/user_objects/firetasks/script_task.py index aa2463b29..b86dc3294 100644 --- a/fireworks/user_objects/firetasks/script_task.py +++ b/fireworks/user_objects/firetasks/script_task.py @@ -92,7 +92,7 @@ def _run_task_internal(self, fw_spec, stdin): return FWAction(stored_data=output) - def _load_params(self, d): + def _load_params(self, d) -> None: if d.get("stdin_file") and d.get("stdin_key"): raise ValueError("ScriptTask cannot process both a key and file as the standard in!") diff --git a/fireworks/user_objects/firetasks/templatewriter_task.py b/fireworks/user_objects/firetasks/templatewriter_task.py index 2fd9e0438..22dc082ef 100644 --- a/fireworks/user_objects/firetasks/templatewriter_task.py +++ b/fireworks/user_objects/firetasks/templatewriter_task.py @@ -31,7 +31,7 @@ class TemplateWriterTask(FiretaskBase): _fw_name = "TemplateWriterTask" - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: if self.get("use_global_spec"): self._load_params(fw_spec) else: @@ -45,7 +45,7 @@ def run_task(self, fw_spec): with open(self.output_file, write_mode) as of: of.write(output) - def _load_params(self, d): + def _load_params(self, d) -> None: self.context = d["context"] self.output_file = d["output_file"] self.append_file = d.get("append") # append to output file? 
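All of the fileio/filepad/template tasks above follow the same FiretaskBase contract that this patch annotates: declare _fw_name plus required_params/optional_params, read parameters off self like a dict, and implement run_task(fw_spec), now -> None when nothing is returned. A minimal hypothetical task in that style (class name and parameters invented for illustration):

    from fireworks.core.firework import FiretaskBase

    class HelloTask(FiretaskBase):  # hypothetical, not part of FireWorks
        _fw_name = "HelloTask"
        required_params = ["name"]
        optional_params = ["greeting"]

        def run_task(self, fw_spec) -> None:
            # Firetask parameters are plain dict entries on self
            greeting = self.get("greeting", "Hello")
            print(f"{greeting}, {self['name']}!")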
diff --git a/fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py b/fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py index 454043769..098213fa1 100644 --- a/fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py +++ b/fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py @@ -28,7 +28,7 @@ def afunc(array, power): class CommandLineTaskTest(unittest.TestCase): """run tests for CommandLineTask.""" - def test_command_line_task_1(self): + def test_command_line_task_1(self) -> None: """Input from string to stdin, output from stdout to string.""" params = { "command_spec": { @@ -49,7 +49,7 @@ def test_command_line_task_1(self): output_string = action.mod_spec[0]["_push"]["output string"]["value"] assert output_string == "Hello world!" - def test_command_line_task_2(self): + def test_command_line_task_2(self) -> None: """ input from string to data, output from stdout to file; input from file to stdin, output from stdout to string and from file. @@ -94,7 +94,7 @@ def test_command_line_task_2(self): os.remove(filename) os.remove(output_file) - def test_command_line_task_3(self): + def test_command_line_task_3(self) -> None: """Input from string to data with command line options.""" import platform @@ -155,7 +155,7 @@ def test_command_line_task_3(self): assert time_stamp_1[11:19] == time_stamp_2[11:19] os.remove(filename) - def test_command_line_task_4(self): + def test_command_line_task_4(self) -> None: """Multiple string inputs, multiple file outputs.""" params = { "command_spec": { @@ -188,7 +188,7 @@ def test_command_line_task_4(self): class ForeachTaskTest(unittest.TestCase): """run tests for ForeachTask.""" - def test_foreach_pytask(self): + def test_foreach_pytask(self) -> None: """Run PyTask for a list of numbers.""" numbers = [0, 1, 2, 3, 4] power = 2 @@ -209,7 +209,7 @@ def test_foreach_pytask(self): for number, result in zip(numbers, results): assert result == pow(number, power) - def test_foreach_commandlinetask(self): + def test_foreach_commandlinetask(self) -> None: """Run CommandLineTask for a list of input data.""" inputs = ["black", "white", 2.5, 17] worklist = [{"source": {"type": "data", "value": s}} for s in inputs] @@ -242,7 +242,7 @@ def test_foreach_commandlinetask(self): class JoinDictTaskTest(unittest.TestCase): """run tests for JoinDictTask.""" - def test_join_dict_task(self): + def test_join_dict_task(self) -> None: """Joins dictionaries into a new or existing dict in spec.""" temperature = {"value": 273.15, "units": "Kelvin"} pressure = {"value": 1.2, "units": "bar"} @@ -268,7 +268,7 @@ def test_join_dict_task(self): class JoinListTaskTest(unittest.TestCase): """run tests for JoinListTask.""" - def test_join_list_task(self): + def test_join_list_task(self) -> None: """Joins items into a new or existing list in spec.""" temperature = {"value": 273.15, "units": "Kelvin"} pressure = {"value": 1.2, "units": "bar"} @@ -291,7 +291,7 @@ def test_join_list_task(self): class ImportDataTaskTest(unittest.TestCase): """run tests for ImportDataTask.""" - def test_import_data_task(self): + def test_import_data_task(self) -> None: """Loads data from a file into spec.""" import json diff --git a/fireworks/user_objects/firetasks/tests/test_fileio_tasks.py b/fireworks/user_objects/firetasks/tests/test_fileio_tasks.py index 067e9ec6f..81f34c0e2 100644 --- a/fireworks/user_objects/firetasks/tests/test_fileio_tasks.py +++ b/fireworks/user_objects/firetasks/tests/test_fileio_tasks.py @@ -21,13 +21,13 @@ class FileWriteDeleteTest(unittest.TestCase): - def 
test_init(self): + def test_init(self) -> None: FileWriteTask(files_to_write="hello") FileWriteTask({"files_to_write": "hello"}) with pytest.raises(RuntimeError): FileWriteTask() - def test_run(self): + def test_run(self) -> None: t = load_object_from_file(os.path.join(module_dir, "write.yaml")) t.run_task({}) for i in range(2): @@ -41,11 +41,11 @@ def test_run(self): class CompressDecompressArchiveDirTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.cwd = os.getcwd() os.chdir(module_dir) - def test_compress_dir(self): + def test_compress_dir(self) -> None: c = CompressDirTask(compression="gz") c.run_task({}) assert os.path.exists("delete.yaml.gz") @@ -55,13 +55,13 @@ def test_compress_dir(self): assert not os.path.exists("delete.yaml.gz") assert os.path.exists("delete.yaml") - def test_archive_dir(self): + def test_archive_dir(self) -> None: a = ArchiveDirTask(base_name="archive", format="gztar") a.run_task({}) assert os.path.exists("archive.tar.gz") os.remove("archive.tar.gz") - def tearDown(self): + def tearDown(self) -> None: os.chdir(self.cwd) diff --git a/fireworks/user_objects/firetasks/tests/test_filepad_tasks.py b/fireworks/user_objects/firetasks/tests/test_filepad_tasks.py index cc447df04..7b807a681 100644 --- a/fireworks/user_objects/firetasks/tests/test_filepad_tasks.py +++ b/fireworks/user_objects/firetasks/tests/test_filepad_tasks.py @@ -18,12 +18,12 @@ class FilePadTasksTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.paths = [os.path.join(module_dir, "write.yaml"), os.path.join(module_dir, "delete.yaml")] self.identifiers = ["write", "delete"] self.fp = FilePad.auto_load() - def test_addfilestask_run(self): + def test_addfilestask_run(self) -> None: t = AddFilesTask(paths=self.paths, identifiers=self.identifiers) t.run_task({}) write_file_contents, _ = self.fp.get_file("write") @@ -33,7 +33,7 @@ def test_addfilestask_run(self): with open(self.paths[1]) as f: assert del_file_contents == f.read().encode() - def test_deletefilestask_run(self): + def test_deletefilestask_run(self) -> None: t = DeleteFilesTask(identifiers=self.identifiers) t.run_task({}) file_contents, doc = self.fp.get_file("write") @@ -43,7 +43,7 @@ def test_deletefilestask_run(self): assert file_contents is None assert doc is None - def test_getfilestask_run(self): + def test_getfilestask_run(self) -> None: t = AddFilesTask(paths=self.paths, identifiers=self.identifiers) t.run_task({}) dest_dir = os.path.abspath(".") @@ -56,7 +56,7 @@ def test_getfilestask_run(self): assert write_file_contents == f.read().encode() os.remove(os.path.join(dest_dir, new_file_names[0])) - def test_getfilesbyquerytask_run(self): + def test_getfilesbyquerytask_run(self) -> None: """Tests querying objects from FilePad by metadata.""" t = AddFilesTask(paths=self.paths, identifiers=self.identifiers, metadata={"key": "value"}) t.run_task({}) @@ -69,7 +69,7 @@ def test_getfilesbyquerytask_run(self): assert test_file_contents == file.read().encode() os.remove(os.path.join(dest_dir, new_file_names[0])) - def test_getfilesbyquerytask_run(self): + def test_getfilesbyquerytask_run(self) -> None: """Tests querying objects from FilePad by metadata.""" with open("original_test_file.txt", "w") as f: f.write("Some file with some content") @@ -87,7 +87,7 @@ def test_getfilesbyquerytask_run(self): assert test_file_contents == f.read().encode() os.remove(os.path.join(dest_dir, "queried_test_file.txt")) - def test_getfilesbyquerytask_metafile_run(self): + def 
test_getfilesbyquerytask_metafile_run(self) -> None: """Tests writing metadata to a yaml file.""" with open("original_test_file.txt", "w") as f: f.write("Some file with some content") @@ -113,7 +113,7 @@ def test_getfilesbyquerytask_metafile_run(self): os.remove(os.path.join(dest_dir, "queried_test_file.txt")) os.remove(os.path.join(dest_dir, "queried_test_file.txt.meta.yaml")) - def test_getfilesbyquerytask_ignore_empty_result_run(self): + def test_getfilesbyquerytask_ignore_empty_result_run(self) -> None: """Tests on ignoring empty results from FilePad query.""" dest_dir = os.path.abspath(".") t = GetFilesByQueryTask( @@ -125,7 +125,7 @@ def test_getfilesbyquerytask_ignore_empty_result_run(self): t.run_task({}) # test successful if no exception raised - def test_getfilesbyquerytask_raise_empty_result_run(self): + def test_getfilesbyquerytask_raise_empty_result_run(self) -> None: """Tests on raising exception on empty results from FilePad query.""" dest_dir = os.path.abspath(".") t = GetFilesByQueryTask( @@ -138,7 +138,7 @@ def test_getfilesbyquerytask_raise_empty_result_run(self): t.run_task({}) # test successful if exception raised - def test_getfilesbyquerytask_ignore_degenerate_file_name(self): + def test_getfilesbyquerytask_ignore_degenerate_file_name(self) -> None: """Tests on ignoring degenerate file name in result from FilePad query.""" with open("degenerate_file.txt", "w") as f: f.write("Some file with some content") @@ -158,7 +158,7 @@ def test_getfilesbyquerytask_ignore_degenerate_file_name(self): t.run_task({}) # test successful if no exception raised - def test_getfilesbyquerytask_raise_degenerate_file_name(self): + def test_getfilesbyquerytask_raise_degenerate_file_name(self) -> None: """Tests on raising exception on degenerate file name from FilePad query.""" with open("degenerate_file.txt", "w") as f: f.write("Some file with some content") @@ -179,7 +179,7 @@ def test_getfilesbyquerytask_raise_degenerate_file_name(self): t.run_task({}) # test successful if exception raised - def test_getfilesbyquerytask_sort_ascending_name_run(self): + def test_getfilesbyquerytask_sort_ascending_name_run(self) -> None: """Tests on sorting queried files in ascending order.""" file_contents = ["Some file with some content", "Some other file with some other content"] @@ -209,7 +209,7 @@ def test_getfilesbyquerytask_sort_ascending_name_run(self): with open("degenerate_file.txt") as f: assert file_contents[-1] == f.read() - def test_getfilesbyquerytask_sort_descending_name_run(self): + def test_getfilesbyquerytask_sort_descending_name_run(self) -> None: """Tests on sorting queried files in descending order.""" file_contents = ["Some file with some content", "Some other file with some other content"] @@ -244,7 +244,7 @@ def test_getfilesbyquerytask_sort_descending_name_run(self): os.remove("degenerate_file.txt") - def test_addfilesfrompatterntask_run(self): + def test_addfilesfrompatterntask_run(self) -> None: t = AddFilesTask(paths="*.yaml", directory=module_dir) t.run_task({}) write_file_contents, _ = self.fp.get_file(self.paths[0]) @@ -254,7 +254,7 @@ def test_addfilesfrompatterntask_run(self): with open(self.paths[1]) as f: assert del_file_contents == f.read().encode() - def tearDown(self): + def tearDown(self) -> None: self.fp.reset() diff --git a/fireworks/user_objects/firetasks/tests/test_script_task.py b/fireworks/user_objects/firetasks/tests/test_script_task.py index bc643df4d..f5d82c2b8 100644 --- a/fireworks/user_objects/firetasks/tests/test_script_task.py +++ 
b/fireworks/user_objects/firetasks/tests/test_script_task.py @@ -17,7 +17,7 @@ def afunc(y, z, a): class ScriptTaskTest(unittest.TestCase): - def test_scripttask(self): + def test_scripttask(self) -> None: if os.path.exists("hello.txt"): os.remove("hello.txt") s = ScriptTask({"script": 'echo "hello world"', "stdout_file": "hello.txt"}) @@ -30,7 +30,7 @@ def test_scripttask(self): class PyTaskTest(unittest.TestCase): - def test_task(self): + def test_task(self) -> None: p = PyTask(func="json.dumps", kwargs={"obj": {"hello": "world"}}, stored_data_varname="json") a = p.run_task({}) assert a.stored_data["json"] == '{"hello": "world"}' @@ -40,7 +40,7 @@ def test_task(self): p = PyTask(func="print", args=[3]) p.run_task({}) - def test_task_auto_kwargs(self): + def test_task_auto_kwargs(self) -> None: p = PyTask(func="json.dumps", obj={"hello": "world"}, stored_data_varname="json", auto_kwargs=True) a = p.run_task({}) assert a.stored_data["json"] == '{"hello": "world"}' @@ -50,7 +50,7 @@ def test_task_auto_kwargs(self): p = PyTask(func="print", args=[3]) p.run_task({}) - def test_task_data_flow(self): + def test_task_data_flow(self) -> None: """Test dataflow parameters: inputs, outputs and chunk_number.""" params = {"func": "pow", "inputs": ["arg", "power", "modulo"], "stored_data_varname": "data"} spec = {"arg": 2, "power": 3, "modulo": None} diff --git a/fireworks/user_objects/firetasks/tests/test_templatewriter_task.py b/fireworks/user_objects/firetasks/tests/test_templatewriter_task.py index 2cba8ef3b..6f2b40f20 100644 --- a/fireworks/user_objects/firetasks/tests/test_templatewriter_task.py +++ b/fireworks/user_objects/firetasks/tests/test_templatewriter_task.py @@ -13,7 +13,7 @@ class TemplateWriterTaskTest(unittest.TestCase): - def test_task(self): + def test_task(self) -> None: with open("test_template.txt", "w") as fp: fp.write("option1 = {{opt1}}\noption2 = {{opt2}}") t = TemplateWriterTask( diff --git a/fireworks/user_objects/firetasks/unittest_tasks.py b/fireworks/user_objects/firetasks/unittest_tasks.py index aee84717d..8c4b3b41f 100644 --- a/fireworks/user_objects/firetasks/unittest_tasks.py +++ b/fireworks/user_objects/firetasks/unittest_tasks.py @@ -12,7 +12,7 @@ class TestSerializer(FWSerializable): _fw_name = "TestSerializer Name" - def __init__(self, a, m_date): + def __init__(self, a, m_date) -> None: if not isinstance(m_date, datetime.datetime): raise ValueError("m_date must be a datetime instance!") @@ -34,7 +34,7 @@ def from_dict(cls, m_dict): class ExportTestSerializer(FWSerializable): _fw_name = "TestSerializer Export Name" - def __init__(self, a): + def __init__(self, a) -> None: self.a = a def __eq__(self, other): diff --git a/fireworks/user_objects/queue_adapters/common_adapter.py b/fireworks/user_objects/queue_adapters/common_adapter.py index 2d25e16d6..6e1380b25 100644 --- a/fireworks/user_objects/queue_adapters/common_adapter.py +++ b/fireworks/user_objects/queue_adapters/common_adapter.py @@ -39,7 +39,7 @@ class CommonAdapter(QueueAdapterBase): "MOAB": {"submit_cmd": "msub", "status_cmd": "showq"}, } - def __init__(self, q_type, q_name=None, template_file=None, timeout=None, **kwargs): + def __init__(self, q_type, q_name=None, template_file=None, timeout=None, **kwargs) -> None: """ Initializes a new QueueAdapter object. 
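The PyTaskTest hunk above already shows the call pattern directly; condensed into a standalone snippet with the same values as the test:

    from fireworks.user_objects.firetasks.script_task import PyTask

    # call json.dumps({"hello": "world"}) and store the result under the key "json"
    task = PyTask(func="json.dumps", kwargs={"obj": {"hello": "world"}}, stored_data_varname="json")
    action = task.run_task({})
    assert action.stored_data["json"] == '{"hello": "world"}'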
diff --git a/fireworks/user_objects/queue_adapters/pbs_newt_adapter.py b/fireworks/user_objects/queue_adapters/pbs_newt_adapter.py index 60f8cda51..d966af1c2 100644 --- a/fireworks/user_objects/queue_adapters/pbs_newt_adapter.py +++ b/fireworks/user_objects/queue_adapters/pbs_newt_adapter.py @@ -40,7 +40,7 @@ def get_njobs_in_queue(self, username=None): return len(r.json()) @staticmethod - def _init_auth_session(max_pw_requests=3): + def _init_auth_session(max_pw_requests=3) -> None: """ Initialize the _session class var with an authorized session. Asks for a / password in new sessions, skips PW check for previously authenticated sessions. diff --git a/fireworks/user_objects/queue_adapters/tests/test_common_adapter.py b/fireworks/user_objects/queue_adapters/tests/test_common_adapter.py index 38728a2ff..6f89c0cb5 100644 --- a/fireworks/user_objects/queue_adapters/tests/test_common_adapter.py +++ b/fireworks/user_objects/queue_adapters/tests/test_common_adapter.py @@ -16,7 +16,7 @@ class CommonAdapterTest(unittest.TestCase): - def test_serialization(self): + def test_serialization(self) -> None: p = CommonAdapter( q_type="PBS", q_name="hello", @@ -38,7 +38,7 @@ def test_serialization(self): assert p.get_script_str("here").split("\n")[-1] != "# world" assert "_fw_template_file" not in p.to_dict() - def test_yaml_load(self): + def test_yaml_load(self) -> None: # Test yaml loading. p = load_object_from_file(os.path.join(os.path.dirname(__file__), "pbs.yaml")) p = CommonAdapter(q_type="PBS", q_name="hello", ppnode="8:ib", nnodes=1, hello="world", queue="random") @@ -48,7 +48,7 @@ def test_yaml_load(self): yaml.dump(p.to_dict(), sys.stdout) print() - def test_parse_njobs(self): + def test_parse_njobs(self) -> None: pbs = """ tscc-mgr.sdsc.edu: Req'd Req'd Elap @@ -89,7 +89,7 @@ def test_parse_njobs(self): p = CommonAdapter(q_type="SGE", q_name="hello", queue="all.q", hello="world") assert p._parse_njobs(sge, "ongsp") == 3 - def test_parse_jobid(self): + def test_parse_jobid(self) -> None: p = CommonAdapter(q_type="SLURM", q_name="hello", queue="home-ong", hello="world") sbatch_output = """ SOME PREAMBLE @@ -107,14 +107,14 @@ def test_parse_jobid(self): qsub_output = 'Your job 44275 ("jobname") has been submitted' assert p._parse_jobid(qsub_output) == "44275" - def test_status_cmd_pbs(self): + def test_status_cmd_pbs(self) -> None: p = load_object_from_file( os.path.join(os.path.dirname(__file__), "pbs_override.yaml") # intentional red herring to test deepcopy ) p = CommonAdapter(q_type="PBS") assert p._get_status_cmd("my_name") == ["qstat", "-u", "my_name"] - def test_override(self): + def test_override(self) -> None: p = load_object_from_file(os.path.join(os.path.dirname(__file__), "pbs_override.yaml")) assert p._get_status_cmd("my_name") == ["my_qstatus", "-u", "my_name"] diff --git a/fireworks/utilities/dagflow.py b/fireworks/utilities/dagflow.py index a7b18a9ee..82c7cc7bf 100644 --- a/fireworks/utilities/dagflow.py +++ b/fireworks/utilities/dagflow.py @@ -43,7 +43,7 @@ class DAGFlow(Graph): visualization of workflows. 
""" - def __init__(self, steps, links=None, nlinks=None, name=None, **kwargs): + def __init__(self, steps, links=None, nlinks=None, name=None, **kwargs) -> None: Graph.__init__(self, directed=True, graph_attrs={"name": name}, **kwargs) for step in steps: @@ -136,14 +136,14 @@ def _get_ctrlflow_links(self): links.append((source, target)) return links - def _add_ctrlflow_links(self, links): + def _add_ctrlflow_links(self, links) -> None: """Adds graph edges corresponding to control flow links.""" for link in links: source = self._get_index(link[0]) target = self._get_index(link[1]) self.add_edge(source, target, label=" ") - def _add_dataflow_links(self, step_id=None, mode="both"): + def _add_dataflow_links(self, step_id=None, mode="both") -> None: """Adds graph edges corresponding to data flow links.""" if step_id: vidx = self._get_index(step_id) @@ -205,7 +205,7 @@ def _get_targets(self, step, entity): return lst @staticmethod - def _set_io_fields(step): + def _set_io_fields(step) -> None: """Set io keys as step attributes.""" for item in ["inputs", "outputs", "output"]: step[item] = [] @@ -266,22 +266,22 @@ def _get_leaves(self): """Returns all leaves (i.e. vertices without outgoing edges).""" return [i for i, v in enumerate(self.degree(mode=igraph.OUT)) if v == 0] - def delete_ctrlflow_links(self): + def delete_ctrlflow_links(self) -> None: """Deletes graph edges corresponding to control flow links.""" lst = [link.index for link in list(self.es) if link["label"] == " "] self.delete_edges(lst) - def delete_dataflow_links(self): + def delete_dataflow_links(self) -> None: """Deletes graph edges corresponding to data flow links.""" lst = [link.index for link in list(self.es) if link["label"] != " "] self.delete_edges(lst) - def add_step_labels(self): + def add_step_labels(self) -> None: """Labels the workflow steps (i.e. graph vertices).""" for vertex in list(self.vs): vertex["label"] = vertex["name"] + ", id: " + str(vertex["id"]) - def check(self): + def check(self) -> None: """Correctness check of the workflow.""" try: assert self.is_dag(), "The workflow graph must be a DAG." @@ -292,7 +292,7 @@ def check(self): assert len(self.vs["id"]) == len(set(self.vs["id"])), "Workflow steps must have unique IDs." self.check_dataflow() - def check_dataflow(self): + def check_dataflow(self) -> None: """Checks whether all inputs and outputs match.""" # check for shared output data entities for vertex in list(self.vs): @@ -323,7 +323,7 @@ def to_dict(self): dct["links"] = self._get_ctrlflow_links() return dct - def to_dot(self, filename="wf.dot", view="combined"): + def to_dot(self, filename="wf.dot", view="combined") -> None: """Writes the workflow into a file in DOT format.""" graph = DAGFlow(**self.to_dict()) if view == "controlflow": diff --git a/fireworks/utilities/dict_mods.py b/fireworks/utilities/dict_mods.py index 098a89bb1..afced3f50 100644 --- a/fireworks/utilities/dict_mods.py +++ b/fireworks/utilities/dict_mods.py @@ -70,26 +70,26 @@ class DictMods: supported using a special "->" keyword, e.g. 
{"a->b": 1} """ - def __init__(self): + def __init__(self) -> None: self.supported_actions = {} for i in dir(self): if (not re.match(r"__\w+__", i)) and callable(getattr(self, i)): self.supported_actions["_" + i] = getattr(self, i) @staticmethod - def set(input_dict, settings): + def set(input_dict, settings) -> None: for k, v in settings.items(): (d, key) = get_nested_dict(input_dict, k) d[key] = v @staticmethod - def unset(input_dict, settings): + def unset(input_dict, settings) -> None: for k in settings: (d, key) = get_nested_dict(input_dict, k) del d[key] @staticmethod - def push(input_dict, settings): + def push(input_dict, settings) -> None: for k, v in settings.items(): (d, key) = get_nested_dict(input_dict, k) if key in d: @@ -98,7 +98,7 @@ def push(input_dict, settings): d[key] = [v] @staticmethod - def push_all(input_dict, settings): + def push_all(input_dict, settings) -> None: for k, v in settings.items(): (d, key) = get_nested_dict(input_dict, k) if key in d: @@ -107,7 +107,7 @@ def push_all(input_dict, settings): d[key] = v @staticmethod - def inc(input_dict, settings): + def inc(input_dict, settings) -> None: for k, v in settings.items(): (d, key) = get_nested_dict(input_dict, k) if key in d: @@ -116,14 +116,14 @@ def inc(input_dict, settings): d[key] = v @staticmethod - def rename(input_dict, settings): + def rename(input_dict, settings) -> None: for k, v in settings.items(): if k in input_dict: input_dict[v] = input_dict[k] del input_dict[k] @staticmethod - def add_to_set(input_dict, settings): + def add_to_set(input_dict, settings) -> None: for k, v in settings.items(): (d, key) = get_nested_dict(input_dict, k) if key in d and (not isinstance(d[key], (list, tuple))): @@ -134,7 +134,7 @@ def add_to_set(input_dict, settings): d[key] = v @staticmethod - def pull(input_dict, settings): + def pull(input_dict, settings) -> None: for k, v in settings.items(): (d, key) = get_nested_dict(input_dict, k) if key in d and (not isinstance(d[key], (list, tuple))): @@ -143,7 +143,7 @@ def pull(input_dict, settings): d[key] = [i for i in d[key] if i != v] @staticmethod - def pull_all(input_dict, settings): + def pull_all(input_dict, settings) -> None: for k, v in settings.items(): if k in input_dict and (not isinstance(input_dict[k], (list, tuple))): raise ValueError(f"Keyword {k} does not refer to an array.") @@ -151,7 +151,7 @@ def pull_all(input_dict, settings): DictMods.pull(input_dict, {k: i}) @staticmethod - def pop(input_dict, settings): + def pop(input_dict, settings) -> None: for k, v in settings.items(): (d, key) = get_nested_dict(input_dict, k) if key in d and (not isinstance(d[key], (list, tuple))): @@ -162,7 +162,7 @@ def pop(input_dict, settings): d[key].pop(0) -def apply_mod(modification, obj): +def apply_mod(modification, obj) -> None: """ Note that modify makes actual in-place modifications. It does not return a copy. diff --git a/fireworks/utilities/filepad.py b/fireworks/utilities/filepad.py index e63e8ad85..ba62a32ee 100644 --- a/fireworks/utilities/filepad.py +++ b/fireworks/utilities/filepad.py @@ -36,7 +36,7 @@ def __init__( logdir=None, strm_lvl=None, text_mode=False, - ): + ) -> None: """ Args: host (str): hostname @@ -105,7 +105,7 @@ def __init__( # build indexes self.build_indexes() - def build_indexes(self, indexes=None, background=True): + def build_indexes(self, indexes=None, background=True) -> None: """ Build the indexes. 
@@ -193,7 +193,7 @@ def get_file_by_query(self, query, sort_key=None, sort_direction=pymongo.DESCEND cursor = self.filepad.find(query).sort(sort_key, sort_direction) return [self._get_file_contents(d) for d in cursor] - def delete_file(self, identifier): + def delete_file(self, identifier) -> None: """ Delete the document with the matching identifier. The contents in the gridfs as well as the associated document in the filepad are deleted. @@ -223,7 +223,7 @@ def update_file(self, identifier, path, compress=True): doc = self.filepad.find_one({"identifier": identifier}) return self._update_file_contents(doc, path, compress) - def delete_file_by_id(self, gfs_id): + def delete_file_by_id(self, gfs_id) -> None: """ Args: gfs_id (str): the file id. @@ -231,7 +231,7 @@ def delete_file_by_id(self, gfs_id): self.gridfs.delete(gfs_id) self.filepad.delete_one({"gfs_id": gfs_id}) - def delete_file_by_query(self, query): + def delete_file_by_query(self, query) -> None: """ Args: query (dict): pymongo query dict. @@ -367,7 +367,7 @@ def auto_load(cls): return FilePad.from_db_file(LAUNCHPAD_LOC) return FilePad() - def reset(self): + def reset(self) -> None: """Reset filepad and the gridfs collections.""" self.filepad.delete_many({}) self.db[self.gridfs_coll_name].files.delete_many({}) diff --git a/fireworks/utilities/fw_serializers.py b/fireworks/utilities/fw_serializers.py index 6aa2b83b4..aec94e1f0 100644 --- a/fireworks/utilities/fw_serializers.py +++ b/fireworks/utilities/fw_serializers.py @@ -33,6 +33,7 @@ import json # note that ujson is faster, but at this time does not support "default" in dumps() import pkgutil import traceback +from typing import NoReturn from monty.json import MontyDecoder, MSONable from ruamel.yaml import YAML @@ -206,7 +207,7 @@ def fw_name(self): return get_default_serialization(self.__class__) @abc.abstractmethod - def to_dict(self): + def to_dict(self) -> NoReturn: raise NotImplementedError("FWSerializable object did not implement to_dict()!") def to_db_dict(self): @@ -220,10 +221,10 @@ def as_dict(self): @classmethod @abc.abstractmethod - def from_dict(cls, m_dict): + def from_dict(cls, m_dict) -> NoReturn: raise NotImplementedError("FWSerializable object did not implement from_dict()!") - def __repr__(self): + def __repr__(self) -> str: return json.dumps(self.to_dict(), default=DATETIME_HANDLER) def to_format(self, f_format="json", **kwargs): @@ -266,7 +267,7 @@ def from_format(cls, f_str, f_format="json"): fireworks_schema.validate(dct, cls.__name__) return cls.from_dict(reconstitute_dates(dct)) - def to_file(self, filename, f_format=None, **kwargs): + def to_file(self, filename, f_format=None, **kwargs) -> None: """ Write a serialization of this object to a file. diff --git a/fireworks/utilities/fw_utilities.py b/fireworks/utilities/fw_utilities.py index 61519a7ac..3b63ce2ae 100644 --- a/fireworks/utilities/fw_utilities.py +++ b/fireworks/utilities/fw_utilities.py @@ -72,7 +72,7 @@ def get_fw_logger( return logger -def log_multi(m_logger, msg, log_lvl="info"): +def log_multi(m_logger, msg, log_lvl="info") -> None: """ Args: m_logger (logger): The logger object @@ -86,7 +86,7 @@ def log_multi(m_logger, msg, log_lvl="info"): _log_fnc(msg) -def log_fancy(m_logger, msgs, log_lvl="info", add_traceback=False): +def log_fancy(m_logger, msgs, log_lvl="info", add_traceback=False) -> None: """ A wrapper around the logger messages useful for multi-line logs. 
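The fw_serializers.py hunks above mark the abstract to_dict()/from_dict() stubs as NoReturn; any concrete FWSerializable supplies both, as the test serializers later in this patch do. A minimal hypothetical implementation (class name invented; the serialize_fw decorator imported at the top of this section is assumed to add the _fw_name bookkeeping to the serialized dict):

    from fireworks.utilities.fw_serializers import FWSerializable, serialize_fw

    class PointSerializer(FWSerializable):  # hypothetical
        _fw_name = "PointSerializer"

        def __init__(self, x, y) -> None:
            self.x, self.y = x, y

        @serialize_fw
        def to_dict(self):
            return {"x": self.x, "y": self.y}

        @classmethod
        def from_dict(cls, m_dict):
            return cls(m_dict["x"], m_dict["y"])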
Helps to group log messages by adding a fancy border around it, diff --git a/fireworks/utilities/tests/test_dagflow.py b/fireworks/utilities/tests/test_dagflow.py index b4b8e63eb..a2b94d551 100644 --- a/fireworks/utilities/tests/test_dagflow.py +++ b/fireworks/utilities/tests/test_dagflow.py @@ -16,7 +16,7 @@ class DAGFlowTest(unittest.TestCase): """run tests for DAGFlow class.""" - def setUp(self): + def setUp(self) -> None: try: __import__("igraph", fromlist=["Graph"]) except (ImportError, ModuleNotFoundError): @@ -34,7 +34,7 @@ def setUp(self): ) self.fw3 = Firework(PyTask(func="print", inputs=["second power"]), name="the third one") - def test_dagflow_ok(self): + def test_dagflow_ok(self) -> None: """Construct and replicate.""" from fireworks.utilities.dagflow import DAGFlow @@ -42,7 +42,7 @@ def test_dagflow_ok(self): dagf = DAGFlow.from_fireworks(wfl) DAGFlow(**dagf.to_dict()) - def test_dagflow_loop(self): + def test_dagflow_loop(self) -> None: """Loop in graph.""" from fireworks.utilities.dagflow import DAGFlow @@ -52,7 +52,7 @@ def test_dagflow_loop(self): DAGFlow.from_fireworks(wfl).check() assert msg in str(exc.value) - def test_dagflow_cut(self): + def test_dagflow_cut(self) -> None: """Disconnected graph.""" from fireworks.utilities.dagflow import DAGFlow @@ -62,7 +62,7 @@ def test_dagflow_cut(self): DAGFlow.from_fireworks(wfl).check() assert msg in str(exc.value) - def test_dagflow_link(self): + def test_dagflow_link(self) -> None: """Wrong links.""" from fireworks.utilities.dagflow import DAGFlow @@ -72,7 +72,7 @@ def test_dagflow_link(self): DAGFlow.from_fireworks(wfl).check() assert msg in str(exc.value) - def test_dagflow_missing_input(self): + def test_dagflow_missing_input(self) -> None: """Missing input.""" from fireworks.utilities.dagflow import DAGFlow @@ -89,7 +89,7 @@ def test_dagflow_missing_input(self): DAGFlow.from_fireworks(wfl).check() assert msg in str(exc.value) - def test_dagflow_clashing_inputs(self): + def test_dagflow_clashing_inputs(self) -> None: """Parent firework output overwrites an input in spec.""" from fireworks.utilities.dagflow import DAGFlow @@ -107,7 +107,7 @@ def test_dagflow_clashing_inputs(self): DAGFlow.from_fireworks(wfl).check() assert msg in str(exc.value) - def test_dagflow_race_condition(self): + def test_dagflow_race_condition(self) -> None: """Two parent firework outputs overwrite each other.""" from fireworks.utilities.dagflow import DAGFlow @@ -123,7 +123,7 @@ def test_dagflow_race_condition(self): DAGFlow.from_fireworks(wfl).check() assert msg in str(exc.value) - def test_dagflow_clashing_outputs(self): + def test_dagflow_clashing_outputs(self) -> None: """Subsequent task overwrites output of a task.""" from fireworks.utilities.dagflow import DAGFlow @@ -137,7 +137,7 @@ def test_dagflow_clashing_outputs(self): DAGFlow.from_fireworks(Workflow([fwk], {})).check() assert msg in str(exc.value) - def test_dagflow_non_dataflow_tasks(self): + def test_dagflow_non_dataflow_tasks(self) -> None: """non-dataflow tasks using outputs and inputs keys do not fail.""" from fireworks.core.firework import FiretaskBase from fireworks.utilities.dagflow import DAGFlow @@ -148,7 +148,7 @@ class NonDataFlowTask(FiretaskBase): _fw_name = "NonDataFlowTask" required_params = ["inputs", "outputs"] - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: pass task = NonDataFlowTask(inputs=["first power", "exponent"], outputs=["second power"]) @@ -156,7 +156,7 @@ def run_task(self, fw_spec): wfl = Workflow([self.fw1, fw2], {self.fw1: [fw2], 
fw2: []}) DAGFlow.from_fireworks(wfl).check() - def test_dagflow_view(self): + def test_dagflow_view(self) -> None: """Visualize the workflow graph.""" from fireworks.utilities.dagflow import DAGFlow diff --git a/fireworks/utilities/tests/test_filepad.py b/fireworks/utilities/tests/test_filepad.py index f501d7124..f2d628f1e 100644 --- a/fireworks/utilities/tests/test_filepad.py +++ b/fireworks/utilities/tests/test_filepad.py @@ -7,22 +7,22 @@ class FilePadTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.chgcar_file = os.path.join(module_dir, "CHGCAR.Fe3O4") self.fp = FilePad.auto_load() self.identifier = "Fe3O4" - def test_add_file(self): + def test_add_file(self) -> None: gfs_id, file_identifier = self.fp.add_file(self.chgcar_file, identifier=self.identifier) assert file_identifier == self.identifier assert gfs_id is not None - def test_add_file_with_no_identifier(self): + def test_add_file_with_no_identifier(self) -> None: gfs_id, file_identifier = self.fp.add_file(self.chgcar_file) assert gfs_id is not None assert file_identifier == gfs_id - def test_get_file(self): + def test_get_file(self) -> None: _, file_identifier = self.fp.add_file(self.chgcar_file, identifier="xxx", metadata={"author": "Kiran Mathew"}) file_contents, doc = self.fp.get_file(file_identifier) with open(self.chgcar_file) as file: @@ -35,27 +35,27 @@ def test_get_file(self): assert doc["original_file_path"] == abspath assert doc["compressed"] is True - def test_delete_file(self): + def test_delete_file(self) -> None: _, file_identifier = self.fp.add_file(self.chgcar_file) self.fp.delete_file(file_identifier) contents, doc = self.fp.get_file(file_identifier) assert contents is None assert doc is None - def test_update_file(self): + def test_update_file(self) -> None: gfs_id, _ = self.fp.add_file(self.chgcar_file, identifier="test_update_file") old_id, new_id = self.fp.update_file("test_update_file", self.chgcar_file) assert old_id == gfs_id assert new_id != gfs_id assert not self.fp.gridfs.exists(old_id) - def test_update_file_by_id(self): + def test_update_file_by_id(self) -> None: gfs_id, _ = self.fp.add_file(self.chgcar_file, identifier="some identifier") old, new = self.fp.update_file_by_id(gfs_id, self.chgcar_file) assert old == gfs_id assert new != gfs_id - def tearDown(self): + def tearDown(self) -> None: self.fp.reset() diff --git a/fireworks/utilities/tests/test_fw_serializers.py b/fireworks/utilities/tests/test_fw_serializers.py index 6b6ae81ff..9d447d14e 100644 --- a/fireworks/utilities/tests/test_fw_serializers.py +++ b/fireworks/utilities/tests/test_fw_serializers.py @@ -20,7 +20,7 @@ @explicit_serialize class ExplicitTestSerializer(FWSerializable): - def __init__(self, a): + def __init__(self, a) -> None: self.a = a def __eq__(self, other): @@ -35,7 +35,7 @@ def from_dict(cls, m_dict): class SerializationTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: test_date = datetime.datetime.utcnow() # A basic datetime test serialized object self.obj_1 = TestSerializer("prop1", test_date) @@ -55,50 +55,50 @@ def setUp(self): self.module_dir = os.path.dirname(os.path.abspath(__file__)) - def tearDown(self): + def tearDown(self) -> None: os.remove("test.json") os.remove("test.yaml") - def test_sanity(self): + def test_sanity(self) -> None: assert self.obj_1 == self.obj_1_copy, "The __eq__() method of the TestSerializer is not set up properly!" assert self.obj_1 != self.obj_2, "The __ne__() method of the TestSerializer is not set up properly!" 
assert self.obj_1 == self.obj_1.from_dict( self.obj_1.to_dict() ), "The to/from_dict() methods of the TestSerializer are not set up properly!" - def test_serialize_fw_decorator(self): + def test_serialize_fw_decorator(self) -> None: m_dict = self.obj_1.to_dict() assert m_dict["_fw_name"] == "TestSerializer Name" - def test_json(self): + def test_json(self) -> None: obj1_json_string = str(self.obj_1.to_format()) # default format is JSON, make sure this is true assert self.obj_1.from_format(obj1_json_string) == self.obj_1, "JSON format export / import fails!" - def test_yaml(self): + def test_yaml(self) -> None: obj1_yaml_string = str(self.obj_1.to_format("yaml")) assert self.obj_1.from_format(obj1_yaml_string, "yaml") == self.obj_1, "YAML format export / import fails!" - def test_complex_json(self): + def test_complex_json(self) -> None: obj2_json_string = str(self.obj_2.to_format()) # default format is JSON, make sure this is true assert self.obj_2.from_format(obj2_json_string) == self.obj_2, "Complex JSON format export / import fails!" - def test_complex_yaml(self): + def test_complex_yaml(self) -> None: obj2_yaml_string = str(self.obj_2.to_format("yaml")) assert ( self.obj_2.from_format(obj2_yaml_string, "yaml") == self.obj_2 ), "Complex YAML format export / import fails!" - def test_unicode_json(self): + def test_unicode_json(self) -> None: obj3_json_string = str(self.obj_3.to_format()) # default format is JSON, make sure this is true assert self.obj_3.from_format(obj3_json_string) == self.obj_3, "Unicode JSON format export / import fails!" - def test_unicode_yaml(self): + def test_unicode_yaml(self) -> None: obj3_yaml_string = str(self.obj_3.to_format("yaml")) assert ( self.obj_3.from_format(obj3_yaml_string, "yaml") == self.obj_3 ), "Unicode YAML format export / import fails!" - def test_unicode_json_file(self): + def test_unicode_json_file(self) -> None: with open(os.path.join(self.module_dir, "test_reference.json")) as f, open( "test.json", **ENCODING_PARAMS ) as f2: @@ -108,22 +108,22 @@ def test_unicode_json_file(self): assert self.obj_3.from_file("test.json") == self.obj_3, "Unicode JSON file import fails!" - def test_unicode_yaml_file(self): + def test_unicode_yaml_file(self) -> None: ref_path = os.path.join(self.module_dir, "test_reference.yaml") with open(ref_path, **ENCODING_PARAMS) as f, open("test.yaml", **ENCODING_PARAMS) as f2: assert f.read() == f2.read(), "Unicode JSON file export fails" assert self.obj_3.from_file("test.yaml") == self.obj_3, "Unicode YAML file import fails!" - def test_implicit_serialization(self): + def test_implicit_serialization(self) -> None: assert ( load_object({"a": {"p1": {"p2": 3}}, "_fw_name": "TestSerializer Export Name"}) == self.obj_4 ), "Implicit import fails!" 
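The serializer tests touched above all exercise one pattern: implement to_dict/from_dict, then round-trip through to_format/from_format. A small sketch of that pattern under stated assumptions; PointSerializer is a made-up example class, not part of FireWorks:

from fireworks.utilities.fw_serializers import FWSerializable, explicit_serialize

@explicit_serialize
class PointSerializer(FWSerializable):
    def __init__(self, x, y) -> None:
        self.x = x
        self.y = y

    def to_dict(self):
        # _fw_name identifies the class when the dict is deserialized later
        return {"_fw_name": self.fw_name, "x": self.x, "y": self.y}

    @classmethod
    def from_dict(cls, m_dict):
        return cls(m_dict["x"], m_dict["y"])

p = PointSerializer(1, 2)
yaml_str = str(p.to_format("yaml"))  # JSON is the default format
q = PointSerializer.from_format(yaml_str, "yaml")
assert q.to_dict() == p.to_dict()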
- def test_as_dict(self): + def test_as_dict(self) -> None: assert self.obj_1.as_dict() == self.obj_1.to_dict() - def test_numpy_array(self): + def test_numpy_array(self) -> None: try: import numpy as np except Exception: @@ -135,11 +135,11 @@ def test_numpy_array(self): class ExplicitSerializationTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.s_obj = ExplicitTestSerializer(1) self.s_dict = self.s_obj.to_dict() - def test_explicit_serialization(self): + def test_explicit_serialization(self) -> None: assert load_object(self.s_dict) == self.s_obj diff --git a/fireworks/utilities/tests/test_update_collection.py b/fireworks/utilities/tests/test_update_collection.py index 349e6b605..620694bf1 100644 --- a/fireworks/utilities/tests/test_update_collection.py +++ b/fireworks/utilities/tests/test_update_collection.py @@ -15,7 +15,7 @@ class UpdateCollectionTests(unittest.TestCase): @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.lp = None try: cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") @@ -24,11 +24,11 @@ def setUpClass(cls): raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if cls.lp: cls.lp.connection.drop_database(TESTDB_NAME) - def test_update_path(self): + def test_update_path(self) -> None: self.lp.db.test_coll.insert_one({"foo": "bar", "foo_list": [{"foo1": "bar1"}, {"foo2": "foo/old/path/bar"}]}) update_path_in_collection( self.lp.db, diff --git a/fireworks/utilities/tests/test_visualize.py b/fireworks/utilities/tests/test_visualize.py index 330d334dd..1d82a01ac 100644 --- a/fireworks/utilities/tests/test_visualize.py +++ b/fireworks/utilities/tests/test_visualize.py @@ -21,7 +21,7 @@ def power_wf(): return Workflow([fw1, fw2, fw3], {fw1: [fw2], fw2: [fw3], fw3: []}) -def test_wf_to_graph(power_wf): +def test_wf_to_graph(power_wf) -> None: dag = wf_to_graph(power_wf) assert isinstance(dag, Digraph) @@ -31,7 +31,7 @@ def test_wf_to_graph(power_wf): assert isinstance(dag, Digraph) -def test_plot_wf(power_wf): +def test_plot_wf(power_wf) -> None: plot_wf(power_wf) plot_wf(power_wf, depth_factor=0.5, breadth_factor=1) diff --git a/fireworks/utilities/update_collection.py b/fireworks/utilities/update_collection.py index 9d63f28ac..25070376b 100644 --- a/fireworks/utilities/update_collection.py +++ b/fireworks/utilities/update_collection.py @@ -8,7 +8,7 @@ __date__ = "Dec 08, 2016" -def update_launchpad_data(lp, replacements, **kwargs): +def update_launchpad_data(lp, replacements, **kwargs) -> None: """ If you want to update a text string in your entire FireWorks database with a replacement, use this method. For example, you might want to update a directory name preamble like "/scratch/user1" to "/project/user2". @@ -26,7 +26,7 @@ def update_launchpad_data(lp, replacements, **kwargs): print("Update launchpad data complete.") -def update_path_in_collection(db, collection_name, replacements, query=None, dry_run=False, force_clear=False): +def update_path_in_collection(db, collection_name, replacements, query=None, dry_run=False, force_clear=False) -> None: """ updates the text specified in replacements for the documents in a MongoDB collection. This can be used to mass-update an outdated value (e.g., a directory path or tag) in that collection. 
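update_path_in_collection is the function exercised by the test above. A hedged sketch of a typical call, assuming a LaunchPad configured via the default my_launchpad.yaml; the collection name and paths are placeholders:

from fireworks.core.launchpad import LaunchPad
from fireworks.utilities.update_collection import update_path_in_collection

lp = LaunchPad.auto_load()  # assumes a default my_launchpad.yaml is available

# Rewrite an outdated path prefix in every document of one collection.
update_path_in_collection(
    lp.db,
    collection_name="fireworks",
    replacements={"/scratch/user1": "/project/user2"},
    dry_run=True,  # flag from the signature above; assumed here to preview rather than write
)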
diff --git a/fireworks/utilities/visualize.py b/fireworks/utilities/visualize.py index f1e04dc85..0c079c112 100644 --- a/fireworks/utilities/visualize.py +++ b/fireworks/utilities/visualize.py @@ -25,7 +25,7 @@ def plot_wf( markersize=10, markerfacecolor="blue", fontsize=12, -): +) -> None: """ Generate a visual representation of the workflow. Useful for checking whether the firework connections are in order before launching the workflow. diff --git a/fw_tutorials/dynamic_wf/printjob_task.py b/fw_tutorials/dynamic_wf/printjob_task.py index 0427d1907..2081402f1 100644 --- a/fw_tutorials/dynamic_wf/printjob_task.py +++ b/fw_tutorials/dynamic_wf/printjob_task.py @@ -10,7 +10,7 @@ class PrintJobTask(FiretaskBase): _fw_name = "Print Job Task" - def run_task(self, fw_spec): + def run_task(self, fw_spec) -> None: job_info_array = fw_spec["_job_info"] prev_job_info = job_info_array[-1] diff --git a/fw_tutorials/python/python_examples.py b/fw_tutorials/python/python_examples.py index 9a11d2ba0..b59bc3134 100644 --- a/fw_tutorials/python/python_examples.py +++ b/fw_tutorials/python/python_examples.py @@ -19,7 +19,7 @@ def setup(): return launchpad -def basic_fw_ex(): +def basic_fw_ex() -> None: print("--- BASIC FIREWORK EXAMPLE ---") # setup @@ -34,7 +34,7 @@ def basic_fw_ex(): launch_rocket(launchpad, FWorker()) -def rapid_fire_ex(): +def rapid_fire_ex() -> None: print("--- RAPIDFIRE EXAMPLE ---") # setup @@ -55,7 +55,7 @@ def rapid_fire_ex(): rapidfire(launchpad, FWorker()) -def multiple_tasks_ex(): +def multiple_tasks_ex() -> None: print("--- MULTIPLE FIRETASKS EXAMPLE ---") # setup @@ -72,7 +72,7 @@ def multiple_tasks_ex(): rapidfire(launchpad, FWorker()) -def basic_wf_ex(): +def basic_wf_ex() -> None: print("--- BASIC WORKFLOW EXAMPLE ---") # setup diff --git a/tasks.py b/tasks.py index a36980eaa..bbdc164ed 100644 --- a/tasks.py +++ b/tasks.py @@ -21,7 +21,7 @@ @task -def make_doc(ctx): +def make_doc(ctx) -> None: with cd("docs_rst"): ctx.run("sphinx-apidoc -o . -f ../fireworks") ctx.run("make html") @@ -36,7 +36,7 @@ def make_doc(ctx): @task -def update_doc(ctx): +def update_doc(ctx) -> None: make_doc(ctx) with cd("docs"): ctx.run("git add .") @@ -45,12 +45,12 @@ def update_doc(ctx): @task -def publish(ctx): +def publish(ctx) -> None: ctx.run("python setup.py release") @task -def release_github(ctx): +def release_github(ctx) -> None: payload = { "tag_name": fw_version, "target_commitish": "master", @@ -71,13 +71,13 @@ def release_github(ctx): @task -def release(ctx): +def release(ctx) -> None: publish(ctx) update_doc(ctx) release_github(ctx) @task -def open_doc(ctx): +def open_doc(ctx) -> None: pth = os.path.abspath("docs/index.html") webbrowser.open("file://" + pth)
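For completeness, a small sketch of the visualization helpers whose signatures are annotated above, assuming matplotlib (for plot_wf) and graphviz (for wf_to_graph) are installed; the ScriptTask commands are placeholders:

from fireworks import Firework, ScriptTask, Workflow
from fireworks.utilities.visualize import plot_wf, wf_to_graph

fw1 = Firework(ScriptTask.from_str('echo "step 1"'), name="first")
fw2 = Firework(ScriptTask.from_str('echo "step 2"'), name="second")
wf = Workflow([fw1, fw2], {fw1: [fw2], fw2: []})

plot_wf(wf, depth_factor=0.5, breadth_factor=1)  # matplotlib plot; mirrors test_plot_wf
dag = wf_to_graph(wf)                            # graphviz Digraph; mirrors test_wf_to_graph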