diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py
index 5a559818b..fef85ff95 100644
--- a/fireworks/core/launchpad.py
+++ b/fireworks/core/launchpad.py
@@ -11,6 +11,7 @@
 from collections import defaultdict
 from itertools import chain
 
+
 import gridfs
 from bson import ObjectId
 from monty.os.path import zpath
@@ -20,6 +21,8 @@
 from tqdm import tqdm
 
 from fireworks.core.firework import Firework, FWAction, Launch, Tracker, Workflow
+from fireworks.utilities import fw_id_from_reservation_id
+from fireworks.utilities import reservation_id_from_fw_id
 from fireworks.fw_config import MongoClient
 from fireworks.fw_config import (
     GRIDFS_FALLBACK_COLLECTION,
@@ -1192,18 +1195,17 @@ def reserve_fw(self, fworker, launch_dir, host=None, ip=None, fw_id=None):
         """
         return self.checkout_fw(fworker, launch_dir, host=host, ip=ip, fw_id=fw_id, state="RESERVED")
 
-    def get_fw_ids_from_reservation_id(self, reservation_id):
+    def get_fw_id_from_reservation_id(self, reservation_id):
         """
-        Given the reservation id, return the list of firework ids.
+        Given the reservation id, return the id of the firework associated with it.
 
         Args:
             reservation_id (int)
 
         Returns:
-            [int]: list of firework ids.
+            int: the firework id ("N/A" if FW.json has no fw_id key), or None if it cannot be found.
         """
-        l_id = self.launches.find_one({"state_history.reservation_id": reservation_id}, {"launch_id": 1})["launch_id"]
-        return [fw["fw_id"] for fw in self.fireworks.find({"launches": l_id}, {"fw_id": 1})]
+        return fw_id_from_reservation_id.get_fwid(reservation_id)
 
     def cancel_reservation_by_reservation_id(self, reservation_id) -> None:
         """Given the reservation id, cancel the reservation and rerun the corresponding fireworks."""
@@ -1217,14 +1219,10 @@ def cancel_reservation_by_reservation_id(self, reservation_id) -> None:
 
     def get_reservation_id_from_fw_id(self, fw_id):
         """Given the firework id, return the reservation id."""
-        fw = self.fireworks.find_one({"fw_id": fw_id}, {"launches": 1})
-        if fw:
-            for launch in self.launches.find({"launch_id": {"$in": fw["launches"]}}, {"state_history": 1}):
-                for d in launch["state_history"]:
-                    if "reservation_id" in d:
-                        return d["reservation_id"]
-            return None
-        return None
+        jobid = reservation_id_from_fw_id.main(fw_id)
+        if jobid is None:
+            print("No matching fw_id-JobID pair. The firework may be a lost run")
+        return jobid
 
     def cancel_reservation(self, launch_id) -> None:
         """Given the launch id, cancel the reservation and rerun the fireworks."""
diff --git a/fireworks/utilities/fw_id_from_reservation_id.py b/fireworks/utilities/fw_id_from_reservation_id.py
new file mode 100644
index 000000000..7fe4c139f
--- /dev/null
+++ b/fireworks/utilities/fw_id_from_reservation_id.py
@@ -0,0 +1,181 @@
+"""Find the fw_id of the firework behind a SLURM job (reservation) id.
+
+``scontrol`` reports the job's StdOut path; the ``launcher_*`` directory found at
+or below that file's directory with the highest-sorting suffix holds the
+``FW.json`` that ``fw_id`` is read from. When the command cannot be run on this
+machine, the same steps are carried out on a remote host over SSH.
+"""
+
+import getpass
+import json
+import os
+import re
+import shlex
+import subprocess
+
+import paramiko
+
+
+def execute_command(command):
+    """Run a shell command locally, falling back to a remote host over SSH.
+
+    Args:
+        command (str): command line to run.
+
+    Returns:
+        (str, paramiko.SSHClient or None): stripped stdout of the command, plus
+            the connected SSH client if the remote fallback was used (the caller
+            must close it) or None if the command succeeded locally.
+    """
+    result = subprocess.run(command, shell=True, capture_output=True, text=True)
+    if result.returncode == 0:
+        return result.stdout.strip(), None
+    # A non-zero exit status triggers the interactive SSH fallback.
+    print("Running fireworks locally")
+    return ssh_login(command)
+
+
+def extract_username_hostname(input_string):
+    """Split ``username@hostname`` into its two parts.
+
+    Args:
+        input_string (str): text of the form ``username@hostname``.
+
+    Returns:
+        (str, str): the username and the hostname.
+
+    Raises:
+        ValueError: if the text does not have the form ``username@hostname``.
+    """
+    match = re.match(r"(?P<username>[^@]+)@(?P<hostname>.+)", input_string)
+    if match is None:
+        raise ValueError("The input does not match the required format 'username@hostname'.")
+    return match.group("username"), match.group("hostname")
+
+
+def ssh_login(command):
+    """Prompt for SSH credentials, connect, and run a command on the remote host.
+
+    Args:
+        command (str): command line to run remotely.
+
+    Returns:
+        (str, paramiko.SSHClient): stripped stdout of the command and the
+            connected client, which the caller must close.
+
+    Raises:
+        ValueError: if the entered login is not of the form ``username@hostname``.
+    """
+    username, hostname = extract_username_hostname(input("Enter username@hostname: ").strip())
+    password = getpass.getpass("Enter password+OTP: ")
+
+    ssh = paramiko.SSHClient()
+    # NOTE(review): AutoAddPolicy accepts unknown host keys without verification;
+    # confirm this is acceptable for the hosts this is used with.
+    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+    try:
+        ssh.connect(hostname, username=username, password=password)
+        _stdin, stdout, stderr = ssh.exec_command(command)
+        output = stdout.read().decode("utf-8").strip()
+        errors = stderr.read().decode("utf-8").strip()
+    except Exception:
+        # Close the client before re-raising so a failed login does not leak it.
+        ssh.close()
+        raise
+    if errors:
+        # Output on stderr is reported but tolerated; stdout is still returned.
+        print(f"Command failed: {command}\n{errors}")
+    return output, ssh
+
+
+def get_fwid(jobid):
+    """Return the fw_id of the firework associated with a SLURM job.
+
+    Args:
+        jobid (int or str): SLURM job id.
+
+    Returns:
+        The ``fw_id`` stored in the job's FW.json ("N/A" if the file has no such
+            key), or None if the launch directory or FW.json could not be found.
+    """
+    job_info, ssh = execute_command(f"scontrol show jobid {shlex.quote(str(jobid))}")
+    try:
+        return find_worker(job_info, ssh)
+    finally:
+        # Close the connection even when the lookup raises.
+        if ssh is not None:
+            ssh.close()
+
+
+def _run_remote(ssh, command, failure):
+    """Run ``command`` over ``ssh`` and return its stripped stdout.
+
+    Raises:
+        RuntimeError: prefixed with ``failure`` if the command wrote to stderr.
+    """
+    _stdin, stdout, stderr = ssh.exec_command(command)
+    output = stdout.read().decode("utf-8").strip()
+    errors = stderr.read().decode("utf-8").strip()
+    if errors:
+        raise RuntimeError(f"{failure}: {errors}")
+    return output
+
+
+def find_worker(job_info, ssh):
+    """Read the fw_id from the launcher directory of a SLURM job.
+
+    Args:
+        job_info (str): output of ``scontrol show jobid <id>``.
+        ssh (paramiko.SSHClient or None): connected client when the job files
+            are on a remote host, None to look on the local filesystem.
+
+    Returns:
+        The ``fw_id`` from FW.json ("N/A" if the key is missing), or None if the
+            StdOut path, a ``launcher_*`` directory or FW.json cannot be found.
+
+    Raises:
+        RuntimeError: if a remote command writes to stderr.
+    """
+    stdout_path = ""
+    for line in job_info.splitlines():
+        if "StdOut=" in line:
+            stdout_path = line.split("=", 1)[1]
+            break
+    if not stdout_path:
+        print("StdOut path not found in job information")
+        return None
+
+    base_dir = os.path.dirname(stdout_path)
+    if ssh is not None:
+        find_cmd = f"find {shlex.quote(base_dir)} -type d -name 'launcher_*'"
+        launch_dirs = _run_remote(ssh, find_cmd, "Failed to find launch directories").splitlines()
+    else:
+        # Walks the tree like ``find <base_dir> -type d -name 'launcher_*'``, without
+        # spawning a shell or changing the working directory of the process.
+        launch_dirs = [
+            root for root, _dirs, _files in os.walk(base_dir) if os.path.basename(root).startswith("launcher_")
+        ]
+
+    # Pick the directory whose text after the last "_" sorts highest
+    # (presumably the most recent launch timestamp -- TODO confirm).
+    largest_dir = max(launch_dirs, key=lambda d: d.split("_")[-1], default=None)
+    if largest_dir is None:
+        print(f"No launcher_* directory found under {base_dir}")
+        return None
+
+    if ssh is not None:
+        cat_cmd = f"cat {shlex.quote(largest_dir + '/FW.json')}"
+        data = json.loads(_run_remote(ssh, cat_cmd, "Failed to read FW.json"))
+    else:
+        json_file = os.path.join(largest_dir, "FW.json")
+        if not os.path.isfile(json_file):
+            print(f"FW.json not found in {largest_dir}")
+            return None
+        with open(json_file) as f:
+            data = json.load(f)
+
+    spec_mpid = data.get("spec", {}).get("MPID", "N/A")
+    fw_id = data.get("fw_id", "N/A")
+    print(f"spec.MPID: {spec_mpid}")
+    print(f"fw_id: {fw_id}")
+    return fw_id
diff --git a/fireworks/utilities/reservation_id_from_fw_id.py b/fireworks/utilities/reservation_id_from_fw_id.py
new file mode 100644
index 000000000..df81ad2f8
--- /dev/null
+++ b/fireworks/utilities/reservation_id_from_fw_id.py
@@ -0,0 +1,296 @@
+"""Find the SLURM job (reservation) id that is running a given fw_id.
+
+Every job listed by ``squeue --states=R`` for the current user is inspected:
+the ``FW.json`` next to the job's StdOut file, or in the highest-sorting
+``launcher_*`` directory there, gives the fw_id handled by that job. When the
+command cannot be run on this machine, the same steps are carried out on a
+remote host over SSH.
+"""
+
+import getpass
+import json
+import os
+import re
+import shlex
+import subprocess
+import sys
+
+import paramiko
+
+# ANSI color codes for terminal output
+RED = "\033[0;31m"
+CYAN = "\033[0;36m"
+ORANGE = "\033[0;33m"
+NC = "\033[0m"  # No Color
+
+
+def run_command(command, ssh=None):
+    """Run a shell command locally or, when ``ssh`` is given, on the remote host.
+
+    Args:
+        command (str): command line to run.
+        ssh (paramiko.SSHClient or None): connected client, or None to run locally.
+
+    Returns:
+        str: stripped stdout of the command.
+
+    Raises:
+        RuntimeError: if the remote command writes to stderr, or the local
+            command exits with a non-zero status.
+    """
+    if ssh is not None:
+        _stdin, stdout, stderr = ssh.exec_command(command)
+        output = stdout.read().decode("utf-8").strip()
+        errors = stderr.read().decode("utf-8").strip()
+        if errors:
+            raise RuntimeError(f"Command failed: {command}\n{errors}")
+        return output
+
+    result = subprocess.run(command, shell=True, capture_output=True, text=True)
+    if result.returncode != 0:
+        raise RuntimeError("Running locally")
+    return result.stdout.strip()
+
+
+def execute_command(command):
+    """Run a command locally, falling back to an interactive SSH login on failure.
+
+    Args:
+        command (str): command line to run.
+
+    Returns:
+        (str, paramiko.SSHClient or None): stripped stdout of the command, plus
+            the connected SSH client if the remote fallback was used (the caller
+            must close it) or None if the command succeeded locally.
+    """
+    try:
+        return run_command(command), None
+    except (RuntimeError, OSError) as exc:
+        print(exc)
+        return ssh_login(command)
+
+
+def extract_username_hostname(input_string):
+    """Split ``username@hostname`` into its two parts.
+
+    Args:
+        input_string (str): text of the form ``username@hostname``.
+
+    Returns:
+        (str, str): the username and the hostname.
+
+    Raises:
+        ValueError: if the text does not have the form ``username@hostname``.
+    """
+    match = re.match(r"(?P<username>[^@]+)@(?P<hostname>.+)", input_string)
+    if match is None:
+        raise ValueError("The input does not match the required format 'username@hostname'.")
+    return match.group("username"), match.group("hostname")
+
+
+def ssh_login(command):
+    """Prompt for SSH credentials, connect, and run a command on the remote host.
+
+    Args:
+        command (str): command line to run remotely.
+
+    Returns:
+        (str, paramiko.SSHClient): stripped stdout of the command and the
+            connected client, which the caller must close.
+
+    Raises:
+        ValueError: if the entered login is not of the form ``username@hostname``.
+        RuntimeError: if the remote command writes to stderr.
+    """
+    username, hostname = extract_username_hostname(input("Enter username@hostname: ").strip())
+    password = getpass.getpass("Enter password+OTP: ")
+
+    ssh = paramiko.SSHClient()
+    # NOTE(review): AutoAddPolicy accepts unknown host keys without verification;
+    # confirm this is acceptable for the hosts this is used with.
+    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+    try:
+        ssh.connect(hostname, username=username, password=password)
+        return run_command(command, ssh), ssh
+    except Exception:
+        # Close the client before re-raising so a failed login does not leak it.
+        ssh.close()
+        raise
+
+
+def get_stdout_dir(job_id, ssh=None):
+    """Return the StdOut path that ``scontrol`` reports for a job.
+
+    Args:
+        job_id (int or str): SLURM job id.
+        ssh (paramiko.SSHClient or None): connected client, or None to run locally.
+
+    Returns:
+        str: value of the ``StdOut=`` field, or None if the ``scontrol`` output
+            has no such field.
+    """
+    output = run_command(f"scontrol show jobid {shlex.quote(str(job_id))}", ssh)
+    match = re.search(r"StdOut=(\S+)", output)
+    if match is None:
+        print(f"{RED}StdOut path not found in job information{NC}")
+        return None
+    return match.group(1)
+
+
+def load_json(file_path, ssh=None):
+    """Load a JSON file from the local filesystem or, when ``ssh`` is given, the remote host.
+
+    Args:
+        file_path (str): path of the JSON file.
+        ssh (paramiko.SSHClient or None): connected client, or None to read locally.
+
+    Returns:
+        The decoded JSON document.
+    """
+    if ssh is not None:
+        return json.loads(run_command(f"cat {shlex.quote(file_path)}", ssh))
+    with open(file_path) as f:
+        return json.load(f)
+
+
+def process_all_jobs(job_list, ssh=None):
+    """Map the fw_id handled by each listed job to that job's id.
+
+    Args:
+        job_list (str): ``squeue`` output; its first line is skipped as the header.
+        ssh (paramiko.SSHClient or None): connected client, or None to look locally.
+
+    Returns:
+        dict: fw_id -> job id (str) for every job whose fw_id could be read;
+            empty when no jobs are listed.
+    """
+    job_lines = [line for line in job_list.splitlines()[1:] if line.strip()]
+    if not job_lines:
+        print(f"{RED}No jobs found!{NC}")
+        return {}
+
+    fw_dict = {}
+    total = len(job_lines)
+    for done, line in enumerate(job_lines, start=1):
+        job_id = line.split()[0]
+
+        # Text progress bar with the percentage of jobs inspected so far.
+        bar = "#" * done + "-" * (total - done)
+        sys.stdout.write(f"\r[{bar}] {done / total * 100:.2f}%")
+        sys.stdout.flush()
+        print("\n")
+
+        fw_id = process_single_job(job_id, ssh)
+        # Jobs without a readable fw_id must not occupy a key in the mapping.
+        if fw_id is not None:
+            fw_dict[fw_id] = job_id
+    return fw_dict
+
+
+def process_single_job(job_id, ssh=None):
+    """Return the fw_id handled by one SLURM job, or None if it cannot be read."""
+    stdout_dir = get_stdout_dir(job_id, ssh)
+    if stdout_dir is None:
+        return None
+    return dir_rapidfire(os.path.dirname(stdout_dir), ssh)
+
+
+def _read_fw_json(json_file, ssh=None):
+    """Return the decoded FW.json at ``json_file``, or None if it is not available."""
+    if ssh is None:
+        return load_json(json_file) if os.path.isfile(json_file) else None
+    quoted = shlex.quote(json_file)
+    try:
+        return json.loads(run_command(f"test -f {quoted} && cat {quoted}", ssh))
+    except Exception:
+        # Best effort: a missing or unreadable remote file counts as "no FW.json".
+        return None
+
+
+def dir_singleshot(base_dir, ssh=None):
+    """Read the fw_id from the ``FW.json`` directly inside ``base_dir``.
+
+    Args:
+        base_dir (str): directory expected to contain FW.json.
+        ssh (paramiko.SSHClient or None): connected client, or None to look locally.
+
+    Returns:
+        The ``fw_id`` value from FW.json, or None if the file is missing,
+            unreadable or has no ``fw_id`` key.
+    """
+    data = _read_fw_json(os.path.join(base_dir, "FW.json"), ssh)
+    if data is None:
+        print_warning(base_dir)
+        return None
+
+    spec_mpid = data.get("spec", {}).get("MPID")
+    fw_id = data.get("fw_id")
+    if spec_mpid:
+        print(f"spec.MPID: {spec_mpid}")
+    if fw_id:
+        print(f"fw_id: {fw_id}")
+    return fw_id
+
+
+def dir_rapidfire(base_dir, ssh=None):
+    """Read the fw_id from the highest-sorting ``launcher_*`` directory in ``base_dir``.
+
+    Falls back to ``dir_singleshot`` when ``base_dir`` has no ``launcher_*``
+    directory or the chosen one has no readable FW.json.
+
+    Args:
+        base_dir (str): directory holding the job's StdOut file.
+        ssh (paramiko.SSHClient or None): connected client, or None to look locally.
+
+    Returns:
+        The ``fw_id`` value that was found, or None if there is none.
+    """
+    if ssh is not None:
+        try:
+            launcher_dirs = run_command(f"cd {shlex.quote(base_dir)} && ls -d launcher_*", ssh).split()
+        except Exception:
+            # Best effort: any failure to list launcher directories means "try singleshot".
+            return dir_singleshot(base_dir, ssh)
+    else:
+        launcher_dirs = [
+            name
+            for name in os.listdir(base_dir)
+            if name.startswith("launcher_") and os.path.isdir(os.path.join(base_dir, name))
+        ]
+
+    if not launcher_dirs:
+        return dir_singleshot(base_dir, ssh)
+
+    json_file = os.path.join(base_dir, max(launcher_dirs), "FW.json")
+    data = _read_fw_json(json_file, ssh)
+    if data is None:
+        print_warning(base_dir)
+        return dir_singleshot(base_dir, ssh)
+    return data.get("fw_id")
+
+
+def print_warning(dir_path):
+    """Print a warning that no usable FW.json was found for the job directory ``dir_path``."""
+    print(
+        f"{RED}WARNING!{NC} No readable FW.json found for {dir_path}\n"
+        "This slurm job probably doesn't have an FW_ID associated with it; "
+        "check that directory to figure out what happened."
+    )
+
+
+def main(fw_id):
+    """Return the id of the running SLURM job that handles ``fw_id``.
+
+    Args:
+        fw_id: firework id, compared with the ``fw_id`` values read from FW.json.
+
+    Returns:
+        The job id (str) of the matching job, or None if no job listed by
+            ``squeue`` maps to ``fw_id``.
+    """
+    result, ssh = execute_command("squeue --states=R -u ${USER}")
+    try:
+        return process_all_jobs(result, ssh).get(fw_id)
+    finally:
+        if ssh is not None:
+            ssh.close()
diff --git a/setup.py b/setup.py
index 554530e9e..6c9b434d0 100644
--- a/setup.py
+++ b/setup.py
@@ -35,6 +35,7 @@
         "tabulate>=0.7.5",
         "flask>=0.11.1",
         "flask-paginate>=0.4.5",
+        "paramiko>=3.4.0",
         "gunicorn>=19.6.0",
         "tqdm>=4.8.4",
         "importlib-metadata>=4.8.2; python_version<'3.8'",