diff --git a/jwql/instrument_monitors/nircam_monitors/wisp_finder.py b/jwql/instrument_monitors/nircam_monitors/wisp_finder.py index 5fde1eb0b..99aad623e 100755 --- a/jwql/instrument_monitors/nircam_monitors/wisp_finder.py +++ b/jwql/instrument_monitors/nircam_monitors/wisp_finder.py @@ -49,15 +49,22 @@ from torchvision import transforms import torchvision.models as models +from jwql.shared_tasks.shared_tasks import only_one +from jwql.utils import monitor_utils +from jwql.utils.constants import ON_GITHUB_ACTIONS, ON_READTHEDOCS +from jwql.utils.logging_functions import log_info, log_fail from jwql.utils.utils import get_config from jwql.website.apps.jwql.archive_database_update import files_in_filesystem from jwql.website.apps.jwql.models import Anomalies, RootFileInfo from jwql.instrument_monitors.nircam_monitors import prepare_wisp_pngs -if 1>0: +if not ON_GITHUB_ACTIONS and not ON_READTHEDOCS: os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jwql.website.jwql_proj.settings") setup() + # once we have db models defined, uncoment this line + #from jwql.website.apps.jwql.monitor_models.wisp_finder import * + def add_wisp_flag(basename): """Add the wisps flag to the RootFileInfo entry for the given filename @@ -99,22 +106,30 @@ def copy_files_to_working_dir(filepaths): ---------- filepaths : list List of full paths of files to be copied + + Returns + ------- + copied_filepaths : list + List of new locations for the files """ working_dir = get_config()["working"] copied_filepaths = [] for filepath in filepaths: shutil.copy2(filepath, working_dir) copied_filepaths.append(os.path.join(working_dir, os.path.basename(filepath))) - - - print(f'copying {filepath} to {working_dir}') - + logging.info(f'Copying {filepath} to {working_dir}') return copied_filepaths def create_transform(): - """ + """Create a transform function that will be used to modify images + and place them in the format expected by the ML model + + Returns + ------- + transform : torchvision.transforms.transforms.Compose + Image transform model """ transform = transforms.Compose([ transforms.Resize((128, 128)), # Resize images to a fixed size @@ -125,7 +140,13 @@ def create_transform(): def define_model_architecture(): - """ + """Define the basic architecture of the ML model. This will be the framework into which + the model parameters will be loaded, in order to fully define the function. + + Returns + ------- + model : torchvision.models.resnet.ResNet + ResNet model to use for wisp prediction """ # Load pre-trained ResNet-18 model model = models.resnet18(weights='IMAGENET1K_V1') @@ -160,10 +181,14 @@ def define_options(parser=None, usage=None, conflict_handler='resolve'): if parser is None: parser = argparse.ArgumentParser(usage=usage, conflict_handler=conflict_handler) - parser.add_argument('-m', '--model_filename', type=str, default=None, help='Filename of saved ML model. (default=%(default)s)') - parser.add_argument('-s', '--starting_date', type=float, default=None, help='Earliest MJD to search for data. If None, date is retrieved from database.') - parser.add_argument('-e', '--ending_date', type=float, default=None, help='Latest MJD to search for data. If None, the current date is used.') - parser.add_argument('-f', '--file_list', type=str, nargs='+', default=None, help='List of full paths to files to run the monitor on.') + parser.add_argument('-m', '--model_filename', type=str, default=None, + help='Filename of saved ML model. (default=%(default)s)') + parser.add_argument('-s', '--starting_date', type=float, default=None, + help='Earliest MJD to search for data. If None, date is retrieved from database.') + parser.add_argument('-e', '--ending_date', type=float, default=None, + help='Latest MJD to search for data. If None, the current date is used.') + parser.add_argument('-f', '--file_list', type=str, nargs='+', default=None, + help='List of full paths to files to run the monitor on.') return parser @@ -174,6 +199,11 @@ def load_ml_model(model_filename): ---------- model_filename : str Location of file containing the model. e.g. /path/to/my_best_model.pth + + Returns + ------- + model : torchvision.models.resnet.ResNet + ResNet model to use for wisp prediction """ model = define_model_architecture() model.load_state_dict(torch.load(model_filename)) @@ -182,28 +212,56 @@ def load_ml_model(model_filename): def predict_wisp(model, image_path, transform): + """Use the model to predict whether there is a wisp in the image. The model returns + a probability. So we use a threshold to separate those predictions into 'wisp' and + 'no wisp' bins. + + Parameters + ---------- + model : torchvision.models.resnet.ResNet + ResNet model to use for wisp prediction + + image_path : str + Full path to the png file + + transform : torchvision.transforms.transforms.Compose + Image transform function used to modify the input images into the format + expected by the ML model. + + Returns + ------- + prediction_label : str + "wisp" or "no wisp" + """ image_tensor = preprocess_image(image_path, transform) # Preprocess the image with torch.no_grad(): # Make prediction without gradients output = model(image_tensor) - # Interpret the result - #_, predicted_class = torch.max(output, 1) - #class_labels = ["no wisp", "wisp"] - #prediction_label = class_labels[predicted_class.item()] - - # If your model instead outputs a single probability (e.g., for "wisp"), use a threshold - # instead of the lines above + # The model outputs a single probability (e.g., for "wisp"). So, use a threshold + # to determine whether the prediction is wisp or no_wisp. probability = torch.sigmoid(output).item() threshold = 0.5 prediction_label = "wisp" if probability >= threshold else "no wisp" - - #print(f"The model predicts: {prediction_label} with probability {probability:.2f}") return prediction_label def preprocess_image(image_path, transform): """Load the png file and prepare it for input to the model + + Parameters + ---------- + image_path : str + Path and filename of the png file + + transform : torchvision.transforms.transforms.Compose + Image transform function used to modify the input images into the format + expected by the ML model. + + Returns + ------- + image : torch.Tensor + Tensor on which the model will run """ image = Image.open(image_path).convert('RGB') # Ensure image is RGB image = transform(image) # Apply transformations @@ -213,7 +271,7 @@ def preprocess_image(image_path, transform): def query_mast(starttime, endtime): """Query MAST between the given dates. Generate a list of NRCB4 files on which - the wisp model will be applied + the wisp model will be run Parameters ---------- @@ -255,6 +313,16 @@ def remove_duplicate_files(file_list): """When running locally, it's possible to end up with duplicates of some filenames in the list of files, because the files are present in both the public and proprietary lists. This function will remove the duplicates. + + Parameters + ---------- + file_list : list + List of full paths to input files + + Returns + ------- + unique_files : list + List of files with unique basenames """ file_list = np.array(file_list) unique_files = [] @@ -265,6 +333,9 @@ def remove_duplicate_files(file_list): return unique_files +#@only_one +@log_fail +@log_info def run(model_filename=None, starting_date=None, ending_date=None, file_list=None): """Run the wisp finder monitor. From user-input dates or dates retrieved from the database, query MAST for all NIRCam NRCB4 full-frame imaging mode data. For @@ -294,6 +365,11 @@ def run(model_filename=None, starting_date=None, ending_date=None, file_list=Non if model_filename is None: model_filename = get_config()['wisp_finder_ML_model'] + if os.path.isfile(model_filename): + logging.info(f'Using ML model saved in: {model_filename}') + else: + raise FileNotFoundError(f"WARNING: {model_filename} does not exist. Unable to load ML model.") + if file_list is None: # If ending_date is not provided, set it equal to the current time @@ -307,69 +383,82 @@ def run(model_filename=None, starting_date=None, ending_date=None, file_list=Non latest_run_end = get_latest_run() starting_date = latest_run_end + logging.info(f"Using MJD {starting_date} to {ending_date} to search for files") + # Query MAST between starting_date and ending_date, and get a list of files # to run the wisp prediction on. rate_files = query_mast(starting_date, ending_date) + logging.info(f"Found {len(rate_files)} rate files") else: rate_files = file_list + starting_date = 0.0 + ending_date = 0.0 - # Find the location in the filesystem for all files - filepaths_public = files_in_filesystem(rate_files, 'public') - filepaths_proprietary = files_in_filesystem(rate_files, 'proprietary') - filepaths = filepaths_public + filepaths_proprietary - filepaths = remove_duplicate_files(filepaths) - - working_filepaths = copy_files_to_working_dir(filepaths) + if len(rate_files) > 0: + # Find the location in the filesystem for all files + logging.info("Locating files in the filesystem") + filepaths_public = files_in_filesystem(rate_files, 'public') + filepaths_proprietary = files_in_filesystem(rate_files, 'proprietary') + filepaths = filepaths_public + filepaths_proprietary + filepaths = remove_duplicate_files(filepaths) - # Load the trained ML model - model = load_ml_model(model_filename) + logging.info("Copying files from the filesystem to the working directory.") + working_filepaths = copy_files_to_working_dir(filepaths) - # Create transform to use when creating image tensor - transform = create_transform() + # Load the trained ML model + logging.info(f"Loading ML model from {model_filename}") + model = load_ml_model(model_filename) - # For each fits file, create a png file, and have the ML model predict if there is a wisp - for working_filepath in working_filepaths: + # Create transform to use when creating image tensor + transform = create_transform() + # For each fits file, create a png file, and have the ML model predict if there is a wisp + for working_filepath in working_filepaths: - # we can probably find a way to simply create an Image instance and predict, rather than - # saving and then reading in a png... - # Create png - working_dir = os.path.dirname(working_filepath) - png_filename = prepare_wisp_pngs.run(working_filepath, out_dir=working_dir) + # we can probably find a way to simply create an Image instance and predict, rather than + # saving and then reading in a png... - # Predict - prediction = predict_wisp(model, png_filename, transform) + # Create png + working_dir = os.path.dirname(working_filepath) + png_filename = prepare_wisp_pngs.run(working_filepath, out_dir=working_dir) - print(png_filename, prediction) # FOR DEVELOPMENT ONLY. REMOVE BEFORE MERGING + # Predict + prediction = predict_wisp(model, png_filename, transform) - # If a wisp is predicted, set the wisp flag in the anomalies database - if prediction == 'wisp': - print('Found wisp!!') - # Create the rootname. Strip off the path info, and remove '.fits' and the suffix - # (i.e. 'rate'') - rootfile = '_'.join(os.path.basename(working_filepath).split('.')[0].split('_')[0:-1]) + print(png_filename, prediction) # FOR DEVELOPMENT ONLY. REMOVE BEFORE MERGING - # Add the wisp flag to the RootFileInfo object for the rootfile - add_wisp_flag(rootfile) - print('Added wisp flag') - else: - print('No wisp') + # If a wisp is predicted, set the wisp flag in the anomalies database + if prediction == "wisp": + # Create the rootname. Strip off the path info, and remove '.fits' and the suffix + # (i.e. 'rate'') + rootfile = '_'.join(os.path.basename(working_filepath).split('.')[0].split('_')[0:-1]) + logging.info(f"Found wisp in {rootfile}") - # Delete the png and fits files - print(f'Removing {png_filename} and {working_filepath}') - os.remove(png_filename) - os.remove(working_filepath) + # Add the wisp flag to the RootFileInfo object for the rootfile + add_wisp_flag(rootfile) + else: + pass - # Update the database with info about this run of the monitor - if file_list is None: - do_it() + # Delete the png and fits files + os.remove(png_filename) + os.remove(working_filepath) else: - print('What dates do we add to the database in this case?') + # If no rate_files are found, + logging.info(f"No rate files found. Ending monitor run.") + + # Update the database with info about this run of the monitor. We keep the + # staring and ending dates of the search. No need to keep the names of the files + # that are found to contain a wisp, because that info will be in the RootFileInfo + # instances. + do_it() if __name__ == '__main__': + module = os.path.basename(__file__).strip('.py') + start_time, log_file = monitor_utils.initialize_instrument_monitor(module) + parser = define_options() args = parser.parse_args() @@ -377,3 +466,5 @@ def run(model_filename=None, starting_date=None, ending_date=None, file_list=Non file_list=args.file_list, starting_date=args.starting_date, ending_date=args.ending_date) + + monitor_utils.update_monitor_table(module, start_time, log_file)