Docstrings and clean-up
bhilbert4 committed Jan 14, 2025
1 parent efe961a commit edbbe6e
Showing 1 changed file with 152 additions and 61 deletions.
213 changes: 152 additions & 61 deletions jwql/instrument_monitors/nircam_monitors/wisp_finder.py
@@ -49,15 +49,22 @@
from torchvision import transforms
import torchvision.models as models

from jwql.shared_tasks.shared_tasks import only_one
from jwql.utils import monitor_utils
from jwql.utils.constants import ON_GITHUB_ACTIONS, ON_READTHEDOCS
from jwql.utils.logging_functions import log_info, log_fail
from jwql.utils.utils import get_config
from jwql.website.apps.jwql.archive_database_update import files_in_filesystem
from jwql.website.apps.jwql.models import Anomalies, RootFileInfo
from jwql.instrument_monitors.nircam_monitors import prepare_wisp_pngs

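# Configure the Django environment so that the web app's database models imported above
# (e.g. RootFileInfo, Anomalies) can be used from this standalone monitor script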
if not ON_GITHUB_ACTIONS and not ON_READTHEDOCS:
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jwql.website.jwql_proj.settings")
setup()

# Once we have db models defined, uncomment this line
#from jwql.website.apps.jwql.monitor_models.wisp_finder import *


def add_wisp_flag(basename):
"""Add the wisps flag to the RootFileInfo entry for the given filename
@@ -99,22 +106,30 @@ def copy_files_to_working_dir(filepaths):
----------
filepaths : list
List of full paths of files to be copied

Returns
-------
copied_filepaths : list
List of new locations for the files
"""
working_dir = get_config()["working"]
copied_filepaths = []
for filepath in filepaths:
shutil.copy2(filepath, working_dir)
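# copy2 preserves file metadata (e.g. modification times) in addition to the file contents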
copied_filepaths.append(os.path.join(working_dir, os.path.basename(filepath)))

logging.info(f'Copying {filepath} to {working_dir}')

return copied_filepaths


def create_transform():
"""
"""Create a transform function that will be used to modify images
and place them in the format expected by the ML model

Returns
-------
transform : torchvision.transforms.transforms.Compose
Image transform model
"""
transform = transforms.Compose([
transforms.Resize((128, 128)), # Resize images to a fixed size
@@ -125,7 +140,13 @@ def create_transform():


def define_model_architecture():
"""
"""Define the basic architecture of the ML model. This will be the framework into which
the model parameters will be loaded, in order to fully define the function.

Returns
-------
model : torchvision.models.resnet.ResNet
ResNet model to use for wisp prediction
"""
# Load pre-trained ResNet-18 model
model = models.resnet18(weights='IMAGENET1K_V1')
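# Note: the final fully-connected layer is presumably replaced (in the lines collapsed below)
# with a single-output head, since predict_wisp() applies a sigmoid to a single model output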
@@ -160,10 +181,14 @@ def define_options(parser=None, usage=None, conflict_handler='resolve'):
if parser is None:
parser = argparse.ArgumentParser(usage=usage, conflict_handler=conflict_handler)

parser.add_argument('-m', '--model_filename', type=str, default=None,
help='Filename of saved ML model. (default=%(default)s)')
parser.add_argument('-s', '--starting_date', type=float, default=None,
help='Earliest MJD to search for data. If None, date is retrieved from database.')
parser.add_argument('-e', '--ending_date', type=float, default=None,
help='Latest MJD to search for data. If None, the current date is used.')
parser.add_argument('-f', '--file_list', type=str, nargs='+', default=None,
help='List of full paths to files to run the monitor on.')
return parser


@@ -174,6 +199,11 @@ def load_ml_model(model_filename):
----------
model_filename : str
Location of file containing the model. e.g. /path/to/my_best_model.pth

Returns
-------
model : torchvision.models.resnet.ResNet
ResNet model to use for wisp prediction
"""
model = define_model_architecture()
model.load_state_dict(torch.load(model_filename))
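# Note: if the saved weights were written on a GPU host and this machine is CPU-only,
# torch.load() may need map_location='cpu' here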
@@ -182,28 +212,56 @@ def load_ml_model(model_filename):


def predict_wisp(model, image_path, transform):
"""Use the model to predict whether there is a wisp in the image. The model returns
a probability. So we use a threshold to separate those predictions into 'wisp' and
'no wisp' bins.

Parameters
----------
model : torchvision.models.resnet.ResNet
ResNet model to use for wisp prediction
image_path : str
Full path to the png file
transform : torchvision.transforms.transforms.Compose
Image transform function used to modify the input images into the format
expected by the ML model.

Returns
-------
prediction_label : str
"wisp" or "no wisp"
"""
image_tensor = preprocess_image(image_path, transform) # Preprocess the image

with torch.no_grad(): # Make prediction without gradients
output = model(image_tensor)
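# output is the model's raw score (a single logit); it is converted to a probability below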

# The model outputs a single probability (e.g., for "wisp"). So, use a threshold
# to determine whether the prediction is wisp or no_wisp.
probability = torch.sigmoid(output).item()
threshold = 0.5
prediction_label = "wisp" if probability >= threshold else "no wisp"

#print(f"The model predicts: {prediction_label} with probability {probability:.2f}")
return prediction_label


def preprocess_image(image_path, transform):
"""Load the png file and prepare it for input to the model

Parameters
----------
image_path : str
Path and filename of the png file
transform : torchvision.transforms.transforms.Compose
Image transform function used to modify the input images into the format
expected by the ML model.

Returns
-------
image : torch.Tensor
Tensor on which the model will run
"""
image = Image.open(image_path).convert('RGB') # Ensure image is RGB
image = transform(image) # Apply transformations
@@ -213,7 +271,7 @@ def preprocess_image(image_path, transform):

def query_mast(starttime, endtime):
"""Query MAST between the given dates. Generate a list of NRCB4 files on which
the wisp model will be run

Parameters
----------
@@ -255,6 +313,16 @@ def remove_duplicate_files(file_list):
"""When running locally, it's possible to end up with duplicates of some filenames in
the list of files, because the files are present in both the public and proprietary
lists. This function will remove the duplicates.

Parameters
----------
file_list : list
List of full paths to input files

Returns
-------
unique_files : list
List of files with unique basenames
"""
file_list = np.array(file_list)
unique_files = []
@@ -265,6 +333,9 @@ def remove_duplicate_files(file_list):
return unique_files


#@only_one
@log_fail
@log_info
def run(model_filename=None, starting_date=None, ending_date=None, file_list=None):
"""Run the wisp finder monitor. From user-input dates or dates retrieved from
the database, query MAST for all NIRCam NRCB4 full-frame imaging mode data. For
@@ -294,6 +365,11 @@ def run(model_filename=None, starting_date=None, ending_date=None, file_list=None):
if model_filename is None:
model_filename = get_config()['wisp_finder_ML_model']

if os.path.isfile(model_filename):
logging.info(f'Using ML model saved in: {model_filename}')
else:
raise FileNotFoundError(f"WARNING: {model_filename} does not exist. Unable to load ML model.")

if file_list is None:

# If ending_date is not provided, set it equal to the current time
@@ -307,73 +383,88 @@ def run(model_filename=None, starting_date=None, ending_date=None, file_list=None):
latest_run_end = get_latest_run()
starting_date = latest_run_end

logging.info(f"Using MJD {starting_date} to {ending_date} to search for files")

# Query MAST between starting_date and ending_date, and get a list of files
# to run the wisp prediction on.
rate_files = query_mast(starting_date, ending_date)
logging.info(f"Found {len(rate_files)} rate files")

else:
rate_files = file_list
starting_date = 0.0
ending_date = 0.0

if len(rate_files) > 0:
# Find the location in the filesystem for all files
logging.info("Locating files in the filesystem")
filepaths_public = files_in_filesystem(rate_files, 'public')
filepaths_proprietary = files_in_filesystem(rate_files, 'proprietary')
filepaths = filepaths_public + filepaths_proprietary
filepaths = remove_duplicate_files(filepaths)

logging.info("Copying files from the filesystem to the working directory.")
working_filepaths = copy_files_to_working_dir(filepaths)

# Load the trained ML model
logging.info(f"Loading ML model from {model_filename}")
model = load_ml_model(model_filename)

# Create transform to use when creating image tensor
transform = create_transform()

# For each fits file, create a png file, and have the ML model predict if there is a wisp
for working_filepath in working_filepaths:

# we can probably find a way to simply create an Image instance and predict, rather than
# saving and then reading in a png...

# Create png
working_dir = os.path.dirname(working_filepath)
png_filename = prepare_wisp_pngs.run(working_filepath, out_dir=working_dir)

# Predict
prediction = predict_wisp(model, png_filename, transform)

print(png_filename, prediction) # FOR DEVELOPMENT ONLY. REMOVE BEFORE MERGING

# If a wisp is predicted, set the wisp flag in the anomalies database
if prediction == "wisp":
# Create the rootname. Strip off the path info, and remove '.fits' and the suffix
# (i.e. 'rate')
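# e.g. a hypothetical 'jw01234001001_01101_00001_nrcb4_rate.fits' yields
# the rootname 'jw01234001001_01101_00001_nrcb4'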
rootfile = '_'.join(os.path.basename(working_filepath).split('.')[0].split('_')[0:-1])
logging.info(f"Found wisp in {rootfile}")

# Add the wisp flag to the RootFileInfo object for the rootfile
add_wisp_flag(rootfile)
else:
pass

# Delete the png and fits files
os.remove(png_filename)
os.remove(working_filepath)
else:
# If no rate files are found, log that and end this run of the monitor
logging.info("No rate files found. Ending monitor run.")

# Update the database with info about this run of the monitor. We keep the
# starting and ending dates of the search. No need to keep the names of the files
# that are found to contain a wisp, because that info will be in the RootFileInfo
# instances.
do_it()


if __name__ == '__main__':
module = os.path.basename(__file__).strip('.py')
start_time, log_file = monitor_utils.initialize_instrument_monitor(module)

parser = define_options()
args = parser.parse_args()

run(args.model_filename,
file_list=args.file_list,
starting_date=args.starting_date,
ending_date=args.ending_date)

monitor_utils.update_monitor_table(module, start_time, log_file)
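
# Example invocation (dates and model path are hypothetical):
#     python wisp_finder.py -s 60650.0 -e 60660.0 -m /path/to/my_best_model.pth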
