Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changes/9171.extract_2d.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add ability to select multiple slit names or source ids in extract_2d

10 changes: 8 additions & 2 deletions docs/jwst/extract_2d/main.rst
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,16 @@ Step Arguments
The ``extract_2d`` step has various optional arguments that apply to certain observation
modes. For NIRSpec observations there is one applicable argument:

``--slit_name``
name [string value] of a specific slit region to extract. The default value of None
``--slit_names``
names [comma-separated list containing integers or strings] of specific slits to extract. The default value of None
will cause all known slits for the instrument mode to be extracted.

``--source_ids``
source_ids [comma-separated list containing integers or strings] of specific slits to extract. The default value
of None will cause all known slits for the instrument to be extracted.

``slit_names`` and ``source_ids`` can be used at the same time, duplicates will be filtered out.

There are several arguments available for Wide-Field Slitless Spectroscopy (WFSS) and
Time-Series (TSO) grism spectroscopy:

Expand Down
11 changes: 7 additions & 4 deletions jwst/extract_2d/extract_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@


def extract2d(input_model,
slit_name=None,
slit_names=None,
source_ids=None,
reference_files={},
grism_objects=None,
tsgrism_extract_height=None,
Expand All @@ -27,8 +28,10 @@ def extract2d(input_model,
----------
input_model : `~jwst.datamodels.ImageModel` or `~jwst.datamodels.CubeModel`
Input data model.
slit_name : str or int
Slit name.
slit_names : list containing strings or ints
Slit names to be processed.
source_ids : list containing strings or ints
Source ids to be processed.
reference_files : dict
Reference files.
grism_objects : list
Expand Down Expand Up @@ -66,7 +69,7 @@ def extract2d(input_model,
log.info(f'EXP_TYPE {exp_type} with grating=MIRROR not supported for extract 2D')
input_model.meta.cal_step.extract_2d = 'SKIPPED'
return input_model
output_model = nrs_extract2d(input_model, slit_name=slit_name)
output_model = nrs_extract2d(input_model, slit_names=slit_names, source_ids=source_ids)
elif exp_type in slitless_modes:
if exp_type == 'NRC_TSGRISM':
if tsgrism_extract_height is None:
Expand Down
7 changes: 4 additions & 3 deletions jwst/extract_2d/extract_2d_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ class Extract2dStep(Step):
class_alias = "extract_2d"

spec = """
slit_name = string(default=None)
slit_names = force_list(default=None) # slits to be extracted
source_ids = force_list(default=None) # source ids to be extracted
extract_orders = int_list(default=None) # list of orders to extract
grism_objects = list(default=None) # list of grism objects to use
tsgrism_extract_height = integer(default=None) # extraction height in pixels, TSGRISM mode
Expand All @@ -32,9 +33,9 @@ def process(self, input_model, *args, **kwargs):
for reftype in self.reference_file_types:
reffile = self.get_reference_file(input_model, reftype)
reference_file_names[reftype] = reffile if reffile else ""

with datamodels.open(input_model) as dm:
output_model = extract_2d.extract2d(dm, self.slit_name,
output_model = extract_2d.extract2d(dm, self.slit_names,
self.source_ids,
reference_files=reference_file_names,
grism_objects=self.grism_objects,
tsgrism_extract_height=self.tsgrism_extract_height,
Expand Down
70 changes: 56 additions & 14 deletions jwst/extract_2d/nirspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@
log.setLevel(logging.DEBUG)


def nrs_extract2d(input_model, slit_name=None):
def nrs_extract2d(input_model, slit_names=None, source_ids=None):
"""
Main extract_2d function for NIRSpec exposures.

Parameters
----------
input_model : `~jwst.datamodels.ImageModel` or `~jwst.datamodels.CubeModel`
Input data model.
slit_name : str or int
Slit name.
slit_names : list containing strings or ints
Slit names.
source_ids : list containing strings or ints
Source ids.
"""
exp_type = input_model.meta.exposure.type.upper()

Expand All @@ -45,17 +47,7 @@
# This is a kludge but will work for now.
# This model keeps open_slits as an attribute.
open_slits = slit2msa.slits[:]
if slit_name is not None:
new_open_slits = []
slit_name = str(slit_name)
for sub in open_slits:
if str(sub.name) == slit_name:
new_open_slits.append(sub)
break
if len(new_open_slits) == 0:
raise AttributeError("Slit {} not in open slits.".format(slit_name))
open_slits = new_open_slits

open_slits = select_slits(open_slits, slit_names, source_ids)
# NIRSpec BRIGHTOBJ (S1600A1 TSO) mode
if exp_type == 'NRS_BRIGHTOBJ':
# the output model is a single SlitModel
Expand Down Expand Up @@ -118,6 +110,56 @@

return output_model

def select_slits(open_slits, slit_names, source_ids):
"""
Select the slits to process given the full set of open slits and the
slit_names and source_ids lists

Parameters
----------

open_slits : list
list of open slits
slit_names : list
list of slit names to process
source_ids : list
list of source ids to process
"""
open_slit_names = [str(x.name) for x in open_slits]
open_slit_source_ids = [str(x.source_id) for x in open_slits]
selected_open_slits = []
if slit_names is not None:
matched_slits = []
for this_slit in [str(x) for x in slit_names]:
if this_slit in open_slit_names:
matched_slits.append(this_slit)
else:
log.warn(f"Slit {this_slit} is not in the list of open slits.")

Check warning on line 137 in jwst/extract_2d/nirspec.py

View check run for this annotation

Codecov / codecov/patch

jwst/extract_2d/nirspec.py#L137

Added line #L137 was not covered by tests
for sub in open_slits:
if str(sub.name) in matched_slits:
selected_open_slits.append(sub)
if source_ids is not None:
matched_sources = []
for this_id in [str(x) for x in source_ids]:
if this_id in open_slit_source_ids:
matched_sources.append(this_id)
else:
log.warn(f"Source id {this_id} is not in the list of open slits.")

Check warning on line 147 in jwst/extract_2d/nirspec.py

View check run for this annotation

Codecov / codecov/patch

jwst/extract_2d/nirspec.py#L147

Added line #L147 was not covered by tests
for sub in open_slits:
if str(sub.source_id) in matched_sources:
if sub not in selected_open_slits:
selected_open_slits.append(sub)
else:
log.info(f"Source_id {sub.source_id} already selected (name {sub.name})")
if len(selected_open_slits) > 0:
log.info("Slits selected:")
for this_slit in selected_open_slits:
log.info(f"Name: {this_slit.name}, source_id: {this_slit.source_id}")
return selected_open_slits
else:
log.info("All slits selected")
return open_slits


def process_slit(input_model, slit, exp_type):
"""
Expand Down
44 changes: 44 additions & 0 deletions jwst/extract_2d/tests/test_nirspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from astropy.io import fits
from astropy.table import Table
from stdatamodels.jwst.datamodels import ImageModel, CubeModel, MultiSlitModel, SlitModel
from stdatamodels.jwst.transforms.models import Slit

from jwst.assign_wcs import AssignWcsStep
from jwst.extract_2d.extract_2d_step import Extract2dStep
from jwst.extract_2d.nirspec import select_slits


# WCS keywords, borrowed from NIRCam grism tests
Expand Down Expand Up @@ -120,6 +122,19 @@ def create_msa_hdul():

return hdul

def create_list_of_slits():
"""
Each slit is a stdatamodels.jwst.transforms.model.Slit instance.
The only attributes that are used by select_slits() are
name and source_id
"""
slit_list = []
name_list = ['1', '2', '3', '4', '5', '6', 'S200A1']
source_id_list = ['1000', '2000', '3000', '4000', '5000', '6000', '2222']
for name, source_id in zip(name_list, source_id_list):
new_slit = Slit(name=name, source_id=source_id)
slit_list.append(new_slit)
return slit_list

@pytest.fixture
def nirspec_msa_rate(tmp_path):
Expand Down Expand Up @@ -159,6 +174,9 @@ def nirspec_msa_metfl(tmp_path):
hdul.close()
return filename

@pytest.fixture
def nirspec_slit_list():
return create_list_of_slits()

def test_extract_2d_nirspec_msa_fs(nirspec_msa_rate, nirspec_msa_metfl):
model = ImageModel(nirspec_msa_rate)
Expand Down Expand Up @@ -228,3 +246,29 @@ def test_extract_2d_nirspec_bots(nirspec_bots_rateints):

model.close()
result.close()

def test_select_slits(nirspec_slit_list):
slit_list = nirspec_slit_list
# Choose all slits
all_slits = select_slits(slit_list, None, None)
assert all_slits == slit_list
# Just slit with name=='3'
single_name = select_slits(slit_list, ['3'], None)
assert single_name[0].name == '3'
assert single_name[0].source_id == '3000'
# Just slit with source_id == '3000'
single_id = select_slits(slit_list, None, ['2000'])
assert single_id[0].name == '2'
assert single_id[0].source_id == '2000'
# Slits with name == '4' and source_id == '4000' are duplicates
duplicates = select_slits(slit_list, ['1', '4'], ['2000', '4000'])
assert Slit(name='1', source_id='1000') in duplicates
assert Slit(name='2', source_id='2000') in duplicates
assert Slit(name='4', source_id='4000') in duplicates
# Select slit with a non-integer name
non_integer = select_slits(slit_list, ['S200A1'], None)
assert non_integer[0] == Slit(name='S200A1', source_id='2222')
# Select slits with mix of integer and string names
with_integer = select_slits(slit_list, [2, '5'], None)
assert Slit(name='2', source_id='2000') in with_integer
assert Slit(name='5', source_id='5000') in with_integer