Skip to content

Commit 9ca5373

Browse files
authored
Merge pull request #1700 from bhilbert4/wisp-finder-del-figs
wisp finder - nans in png, close figures
2 parents 5e867b0 + 5ae7b35 commit 9ca5373

File tree

4 files changed

+42
-10
lines changed

4 files changed

+42
-10
lines changed

jwql/instrument_monitors/nircam_monitors/prepare_wisp_pngs.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import argparse
99
import numpy as np
10+
from astropy.convolution import Gaussian2DKernel, interpolate_replace_nans
1011
from astropy.io import fits
1112
from astropy.stats import sigma_clipped_stats
1213
import os
@@ -28,6 +29,28 @@ def create_figure(image, outfile):
2829
plt.imshow(image, origin='lower')
2930
plt.axis('off')
3031
plt.savefig(outfile, bbox_inches='tight')
32+
plt.close('all')
33+
34+
35+
def fill_nan_with_nearest_neighbor(arr):
36+
"""
37+
Replaces NaN values in a 2D NumPy array with values interpolated
38+
from the nearest non-NaN neighbors.
39+
40+
Parameters
41+
----------
42+
arr : numpy.ndarray
43+
A 2D NumPy array potentially containing NaN values.
44+
45+
Returns
46+
-------
47+
filled_arr : numpy.ndarray
48+
A new array with NaNs replaced by nearest neighbor interpolation.
49+
"""
50+
kernel = Gaussian2DKernel(x_stddev=1, y_stddev=1)
51+
filled_arr = interpolate_replace_nans(arr, kernel)
52+
53+
return filled_arr
3154

3255

3356
def rescale_array(arr):
@@ -126,6 +149,9 @@ def run(filename, out_dir=None):
126149
"""
127150
data = fits.getdata(filename)
128151

152+
# Replace NaN values with interpolated values from nearest neighbors
153+
data = fill_nan_with_nearest_neighbor(data)
154+
129155
# Get the basename of the input file. This will be used to create
130156
# the output png file name
131157
outfile_base = os.path.basename(filename).split('.')[0]

jwql/instrument_monitors/nircam_monitors/wisp_finder.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def define_model_architecture():
164164
def define_options(parser=None, usage=None, conflict_handler='resolve'):
165165
"""Add command line options
166166
167-
Parrameters
167+
Parameters
168168
-----------
169169
parser : argparse.parser
170170
Parser object
@@ -265,7 +265,7 @@ def predict_wisp(model, image_path, transform):
265265
probability = torch.sigmoid(output).item()
266266
threshold = 0.5
267267
prediction_label = "wisp" if probability >= threshold else "no wisp"
268-
return prediction_label
268+
return prediction_label, probability, threshold
269269

270270

271271
def preprocess_image(image_path, transform):
@@ -308,6 +308,7 @@ def query_mast(starttime, endtime):
308308
rate_files : list
309309
List of filenames
310310
"""
311+
logging.info("Running sci_obs_id query")
311312
sci_obs_id_table = Observations.query_criteria(instrument_name=["NIRCAM/IMAGE"],
312313
provenance_name=["CALJWST"], # Executed observations
313314
t_min=[starttime, endtime]
@@ -317,16 +318,18 @@ def query_mast(starttime, endtime):
317318

318319
# Loop over visits identifying uncalibrated files that are associated
319320
# with them
320-
for exposure in (sci_obs_id_table):
321+
for i, exposure in enumerate(sci_obs_id_table):
321322
products = Observations.get_product_list(exposure)
322323
filtered_products = Observations.filter_products(products,
323324
productType='SCIENCE',
324325
productSubGroupDescription='RATE',
325326
calib_level=[2])
327+
logging.info(f"\tExpore {i+1} of {len(sci_obs_id_table)}: {len(products)} products filters to {len(filtered_products)} rate files")
326328
sci_files_to_download.extend(filtered_products['dataURI'])
327329

328330
# The current ML wisp finder model is only trained for the wisps on the B4 detector,
329331
# so keep only those files. Also, keep only the filenames themselves.
332+
logging.info(f"Sorting {len(sci_files_to_download)} rate files")
330333
rate_files = sorted([fname.replace('mast:JWST/product/', '') for fname in sci_files_to_download if 'nrcb4' in fname])
331334
return rate_files
332335

@@ -477,10 +480,12 @@ def run_predictor(ratefiles, model_file, start_date, end_date):
477480

478481
# Remove any duplicates coming from files that are present in both the
479482
# public and proprietary filesystems
483+
n_filepaths_before = len(filepaths)
480484
filepaths = remove_duplicate_files(filepaths)
485+
n_filepaths_after = len(filepaths)
481486

482487
# Copy files to working directory
483-
logging.info("Copying files from the filesystem to the working directory.")
488+
logging.info(f"Copying {n_filepaths_after} files from the filesystem to the working directory (removed {n_filepaths_before - n_filepaths_after} duplicates).")
484489
working_filepaths = copy_files_to_working_dir(filepaths)
485490

486491
# Load the trained ML model
@@ -497,20 +502,20 @@ def run_predictor(ratefiles, model_file, start_date, end_date):
497502
png_filename = prepare_wisp_pngs.run(working_filepath, out_dir=working_dir)
498503

499504
# Predict
500-
prediction = predict_wisp(model, png_filename, transform)
505+
prediction, probability, threshold = predict_wisp(model, png_filename, transform)
501506

502507
# If a wisp is predicted, set the wisp flag in the anomalies database
503508
if prediction == "wisp":
504509
# Create the rootname. Strip off the path info, and remove '.fits' and the suffix
505510
# (i.e. 'rate'')
506511
rootfile = '_'.join(os.path.basename(working_filepath).split('.')[0].split('_')[0:-1])
507-
logging.info(f"\tFound wisp in {rootfile}\n")
512+
logging.info(f"\tFound wisp in {rootfile} (probability {probability} < threshold {threshold})\n\n")
508513

509514
# Add the wisp flag to the RootFileInfo object for the rootfile
510515
add_wisp_flag(rootfile)
511516
else:
512517
rootfile = '_'.join(os.path.basename(working_filepath).split('.')[0].split('_')[0:-1])
513-
logging.info(f'\tNo wisp in {rootfile}\n')
518+
logging.info(f'\tNo wisp in {rootfile} (probability {probability} < threshold {threshold})\n')
514519

515520
# Delete the png and fits files
516521
os.remove(png_filename)

jwql/utils/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,8 @@
720720
"niriss_readnoise_query_history", "niriss_readnoise_stats",
721721
"nirspec_readnoise_query_history", "nirspec_readnoise_stats",
722722
"miri_ta_query_history", "miri_ta_stats",
723-
"nirspec_ta_query_history", "nirspec_ta_stats", "nirspec_wata_stats", "nirspec_msata_stats"
723+
"nirspec_ta_query_history", "nirspec_ta_stats", "nirspec_wata_stats", "nirspec_msata_stats",
724+
"wisp_finder_b4_query_history"
724725
]
725726

726727
# Suffix for msa files

jwql/website/apps/jwql/monitor_models/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ class Meta:
149149
* Django doesn't have a built-in array data type, so you need to import it from the
150150
database-compatibility layers. The ArrayField takes, as a required argument, the type
151151
of data that makes up the array.
152-
* In the Meta sub-class of the monitor class, the `db_table_comment = 'monitors'` statement is
153-
required so that django knows that the model should be stored in the monitors table.
152+
* In order to store a table in the Monitors database (JWQLDB), you must add that table's name
153+
(Meta.db_table) to the MONITOR_TABLE_NAMES constant in jwql.utils.constants.py
154154
* The `float()` casts are required because the database interface doesn't understand
155155
numpy data types.
156156
* The `list()` cast is required because the database interface doesn't understand the

0 commit comments

Comments
 (0)