diff --git a/changes/9171.extract_2d.rst b/changes/9171.extract_2d.rst new file mode 100644 index 0000000000..47b7b2b040 --- /dev/null +++ b/changes/9171.extract_2d.rst @@ -0,0 +1,2 @@ +Add ability to select multiple slit names or source ids in extract_2d + diff --git a/docs/jwst/extract_2d/main.rst b/docs/jwst/extract_2d/main.rst index 982098b249..fe8775112f 100644 --- a/docs/jwst/extract_2d/main.rst +++ b/docs/jwst/extract_2d/main.rst @@ -297,10 +297,16 @@ Step Arguments The ``extract_2d`` step has various optional arguments that apply to certain observation modes. For NIRSpec observations there is one applicable argument: -``--slit_name`` - name [string value] of a specific slit region to extract. The default value of None +``--slit_names`` + names [comma-separated list containing integers or strings] of specific slits to extract. The default value of None will cause all known slits for the instrument mode to be extracted. +``--source_ids`` + source_ids [comma-separated list containing integers or strings] of specific slits to extract. The default value + of None will cause all known slits for the instrument to be extracted. + +``slit_names`` and ``source_ids`` can be used at the same time, duplicates will be filtered out. + There are several arguments available for Wide-Field Slitless Spectroscopy (WFSS) and Time-Series (TSO) grism spectroscopy: diff --git a/jwst/extract_2d/extract_2d.py b/jwst/extract_2d/extract_2d.py index 2fc8f7f0a6..f80e72d6e3 100644 --- a/jwst/extract_2d/extract_2d.py +++ b/jwst/extract_2d/extract_2d.py @@ -12,7 +12,8 @@ def extract2d(input_model, - slit_name=None, + slit_names=None, + source_ids=None, reference_files={}, grism_objects=None, tsgrism_extract_height=None, @@ -27,8 +28,10 @@ def extract2d(input_model, ---------- input_model : `~jwst.datamodels.ImageModel` or `~jwst.datamodels.CubeModel` Input data model. - slit_name : str or int - Slit name. + slit_names : list containing strings or ints + Slit names to be processed. + source_ids : list containing strings or ints + Source ids to be processed. reference_files : dict Reference files. grism_objects : list @@ -66,7 +69,7 @@ def extract2d(input_model, log.info(f'EXP_TYPE {exp_type} with grating=MIRROR not supported for extract 2D') input_model.meta.cal_step.extract_2d = 'SKIPPED' return input_model - output_model = nrs_extract2d(input_model, slit_name=slit_name) + output_model = nrs_extract2d(input_model, slit_names=slit_names, source_ids=source_ids) elif exp_type in slitless_modes: if exp_type == 'NRC_TSGRISM': if tsgrism_extract_height is None: diff --git a/jwst/extract_2d/extract_2d_step.py b/jwst/extract_2d/extract_2d_step.py index 4eb1a8b72d..0e085815bc 100755 --- a/jwst/extract_2d/extract_2d_step.py +++ b/jwst/extract_2d/extract_2d_step.py @@ -16,7 +16,8 @@ class Extract2dStep(Step): class_alias = "extract_2d" spec = """ - slit_name = string(default=None) + slit_names = force_list(default=None) # slits to be extracted + source_ids = force_list(default=None) # source ids to be extracted extract_orders = int_list(default=None) # list of orders to extract grism_objects = list(default=None) # list of grism objects to use tsgrism_extract_height = integer(default=None) # extraction height in pixels, TSGRISM mode @@ -32,9 +33,9 @@ def process(self, input_model, *args, **kwargs): for reftype in self.reference_file_types: reffile = self.get_reference_file(input_model, reftype) reference_file_names[reftype] = reffile if reffile else "" - with datamodels.open(input_model) as dm: - output_model = extract_2d.extract2d(dm, self.slit_name, + output_model = extract_2d.extract2d(dm, self.slit_names, + self.source_ids, reference_files=reference_file_names, grism_objects=self.grism_objects, tsgrism_extract_height=self.tsgrism_extract_height, diff --git a/jwst/extract_2d/nirspec.py b/jwst/extract_2d/nirspec.py index e6f432811d..e4798b0439 100644 --- a/jwst/extract_2d/nirspec.py +++ b/jwst/extract_2d/nirspec.py @@ -18,7 +18,7 @@ log.setLevel(logging.DEBUG) -def nrs_extract2d(input_model, slit_name=None): +def nrs_extract2d(input_model, slit_names=None, source_ids=None): """ Main extract_2d function for NIRSpec exposures. @@ -26,8 +26,10 @@ def nrs_extract2d(input_model, slit_name=None): ---------- input_model : `~jwst.datamodels.ImageModel` or `~jwst.datamodels.CubeModel` Input data model. - slit_name : str or int - Slit name. + slit_names : list containing strings or ints + Slit names. + source_ids : list containing strings or ints + Source ids. """ exp_type = input_model.meta.exposure.type.upper() @@ -45,17 +47,7 @@ def nrs_extract2d(input_model, slit_name=None): # This is a kludge but will work for now. # This model keeps open_slits as an attribute. open_slits = slit2msa.slits[:] - if slit_name is not None: - new_open_slits = [] - slit_name = str(slit_name) - for sub in open_slits: - if str(sub.name) == slit_name: - new_open_slits.append(sub) - break - if len(new_open_slits) == 0: - raise AttributeError("Slit {} not in open slits.".format(slit_name)) - open_slits = new_open_slits - + open_slits = select_slits(open_slits, slit_names, source_ids) # NIRSpec BRIGHTOBJ (S1600A1 TSO) mode if exp_type == 'NRS_BRIGHTOBJ': # the output model is a single SlitModel @@ -118,6 +110,56 @@ def nrs_extract2d(input_model, slit_name=None): return output_model +def select_slits(open_slits, slit_names, source_ids): + """ + Select the slits to process given the full set of open slits and the + slit_names and source_ids lists + + Parameters + ---------- + + open_slits : list + list of open slits + slit_names : list + list of slit names to process + source_ids : list + list of source ids to process + """ + open_slit_names = [str(x.name) for x in open_slits] + open_slit_source_ids = [str(x.source_id) for x in open_slits] + selected_open_slits = [] + if slit_names is not None: + matched_slits = [] + for this_slit in [str(x) for x in slit_names]: + if this_slit in open_slit_names: + matched_slits.append(this_slit) + else: + log.warn(f"Slit {this_slit} is not in the list of open slits.") + for sub in open_slits: + if str(sub.name) in matched_slits: + selected_open_slits.append(sub) + if source_ids is not None: + matched_sources = [] + for this_id in [str(x) for x in source_ids]: + if this_id in open_slit_source_ids: + matched_sources.append(this_id) + else: + log.warn(f"Source id {this_id} is not in the list of open slits.") + for sub in open_slits: + if str(sub.source_id) in matched_sources: + if sub not in selected_open_slits: + selected_open_slits.append(sub) + else: + log.info(f"Source_id {sub.source_id} already selected (name {sub.name})") + if len(selected_open_slits) > 0: + log.info("Slits selected:") + for this_slit in selected_open_slits: + log.info(f"Name: {this_slit.name}, source_id: {this_slit.source_id}") + return selected_open_slits + else: + log.info("All slits selected") + return open_slits + def process_slit(input_model, slit, exp_type): """ diff --git a/jwst/extract_2d/tests/test_nirspec.py b/jwst/extract_2d/tests/test_nirspec.py index 536bea0246..cbfb4c5d07 100644 --- a/jwst/extract_2d/tests/test_nirspec.py +++ b/jwst/extract_2d/tests/test_nirspec.py @@ -3,9 +3,11 @@ from astropy.io import fits from astropy.table import Table from stdatamodels.jwst.datamodels import ImageModel, CubeModel, MultiSlitModel, SlitModel +from stdatamodels.jwst.transforms.models import Slit from jwst.assign_wcs import AssignWcsStep from jwst.extract_2d.extract_2d_step import Extract2dStep +from jwst.extract_2d.nirspec import select_slits # WCS keywords, borrowed from NIRCam grism tests @@ -120,6 +122,19 @@ def create_msa_hdul(): return hdul +def create_list_of_slits(): + """ + Each slit is a stdatamodels.jwst.transforms.model.Slit instance. + The only attributes that are used by select_slits() are + name and source_id + """ + slit_list = [] + name_list = ['1', '2', '3', '4', '5', '6', 'S200A1'] + source_id_list = ['1000', '2000', '3000', '4000', '5000', '6000', '2222'] + for name, source_id in zip(name_list, source_id_list): + new_slit = Slit(name=name, source_id=source_id) + slit_list.append(new_slit) + return slit_list @pytest.fixture def nirspec_msa_rate(tmp_path): @@ -159,6 +174,9 @@ def nirspec_msa_metfl(tmp_path): hdul.close() return filename +@pytest.fixture +def nirspec_slit_list(): + return create_list_of_slits() def test_extract_2d_nirspec_msa_fs(nirspec_msa_rate, nirspec_msa_metfl): model = ImageModel(nirspec_msa_rate) @@ -228,3 +246,29 @@ def test_extract_2d_nirspec_bots(nirspec_bots_rateints): model.close() result.close() + +def test_select_slits(nirspec_slit_list): + slit_list = nirspec_slit_list + # Choose all slits + all_slits = select_slits(slit_list, None, None) + assert all_slits == slit_list + # Just slit with name=='3' + single_name = select_slits(slit_list, ['3'], None) + assert single_name[0].name == '3' + assert single_name[0].source_id == '3000' + # Just slit with source_id == '3000' + single_id = select_slits(slit_list, None, ['2000']) + assert single_id[0].name == '2' + assert single_id[0].source_id == '2000' + # Slits with name == '4' and source_id == '4000' are duplicates + duplicates = select_slits(slit_list, ['1', '4'], ['2000', '4000']) + assert Slit(name='1', source_id='1000') in duplicates + assert Slit(name='2', source_id='2000') in duplicates + assert Slit(name='4', source_id='4000') in duplicates + # Select slit with a non-integer name + non_integer = select_slits(slit_list, ['S200A1'], None) + assert non_integer[0] == Slit(name='S200A1', source_id='2222') + # Select slits with mix of integer and string names + with_integer = select_slits(slit_list, [2, '5'], None) + assert Slit(name='2', source_id='2000') in with_integer + assert Slit(name='5', source_id='5000') in with_integer \ No newline at end of file