Skip to content

Commit f76fe12

Browse files
JP-3930: Step function for opening models (#9723)
1 parent 9850766 commit f76fe12

File tree

12 files changed

+600
-69
lines changed

12 files changed

+600
-69
lines changed

changes/9723.stpipe.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add a ``prepare_output`` method to the JwstStep class, to support input datamodel handling and conditional copies within processing steps.

jwst/assign_mtwcs/assign_mtwcs_step.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,17 @@ def process(self, input_lib):
3535
`~jwst.datamodels.library.ModelLibrary`
3636
The modified data models.
3737
"""
38-
if not isinstance(input_lib, ModelLibrary):
38+
# Open the input, making a copy if necessary
39+
output_lib = self.prepare_output(input_lib)
40+
if not isinstance(output_lib, ModelLibrary):
3941
try:
40-
input_lib = ModelLibrary(input_lib)
41-
except Exception:
42+
output_lib = ModelLibrary(output_lib)
43+
except (ValueError, TypeError) as err:
4244
log.warning("Input data type is not supported.")
43-
record_step_status(input_lib, "assign_mtwcs", False)
44-
return input_lib
45+
log.debug(f"Error was: {err}")
46+
record_step_status(output_lib, "assign_mtwcs", False)
47+
return output_lib
4548

46-
result = assign_moving_target_wcs(input_lib)
49+
result = assign_moving_target_wcs(output_lib)
50+
record_step_status(result, "assign_mtwcs", True)
4751
return result

jwst/assign_mtwcs/tests/test_mtwcs.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from stdatamodels.jwst import datamodels
44

55
from jwst.assign_mtwcs import AssignMTWcsStep
6-
from jwst.datamodels import ModelLibrary
6+
from jwst.datamodels import ModelContainer, ModelLibrary
77

88

99
@pytest.mark.parametrize("errtype", ["ra", "dec", "both", "none"])
@@ -70,3 +70,47 @@ def test_mt_slitmodel(errtype):
7070
zero = result.borrow(0)
7171
assert zero.meta.wcs.output_frame.name == expected_frame
7272
result.shelve(zero, 0, modify=False)
73+
74+
75+
@pytest.mark.parametrize("success", [True, False])
def test_output_is_not_input(monkeypatch, success):
    """
    Check that the step leaves non-library input untouched.

    Only model/container input is covered here. A ModelLibrary input is
    deliberately not copied by the step, since avoiding extra copies is
    assumed to matter most in that case.
    """
    if success:
        expected_type = ModelLibrary
        expected_status = "COMPLETE"
    else:
        expected_type = ModelContainer
        expected_status = "SKIPPED"

        # Force the ModelLibrary constructor to fail so that the step
        # takes its "skipped" branch.
        def raise_error(*args, **kwargs):
            raise ValueError("test")

        monkeypatch.setattr(ModelLibrary, "__init__", raise_error)

    file_path = get_pkg_data_filename("data/test_mt_asn.json", package="jwst.assign_mtwcs.tests")
    with datamodels.open(file_path) as container:
        result = AssignMTWcsStep.call(container)
        assert isinstance(result, expected_type)

        with result:
            for im, input_im in zip(result, container):
                # The output carries the step status ...
                assert im.meta.cal_step.assign_mtwcs == expected_status
                # ... while the input is a distinct, unmodified object.
                assert im is not input_im
                assert input_im.meta.cal_step.assign_mtwcs is None

                # Only a library result needs its models shelved again.
                if success:
                    result.shelve(im, modify=False)
111+
112+
113+
def test_input_not_supported(caplog):
    """Check that a warning is logged when the step is given a single ImageModel."""
    AssignMTWcsStep.call(datamodels.ImageModel())
    assert "Input data type is not supported" in caplog.text

jwst/datamodels/container.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def from_asn(self, asn_data):
310310
if attr in RECOGNIZED_MEMBER_FIELDS:
311311
if attr == "tweakreg_catalog":
312312
if val.strip():
313-
val = asn_dir / val
313+
val = str(asn_dir / val)
314314
else:
315315
val = None
316316

jwst/outlier_detection/outlier_detection_step.py

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,17 @@ def process(self, input_data):
7676
`~jwst.datamodels.library.ModelLibrary`
7777
The modified input data with DQ flags set for detected outliers.
7878
"""
79-
# Open the input data, making a copy as needed.
80-
result_models = self._open_models(input_data)
81-
8279
# determine the "mode" (if not set by the pipeline)
8380
mode = self._guess_mode(input_data)
81+
82+
# Open the input data, making a copy as needed.
83+
if mode != "imaging":
84+
result_models = self.prepare_output(input_data)
85+
else:
86+
# Skip loading datamodels into memory in this case - allow the
87+
# ModelLibrary to handle it, later
88+
result_models = self.prepare_output(input_data, open_models=False)
89+
8490
if mode is None:
8591
record_step_status(result_models, "outlier_detection", False)
8692
return result_models
@@ -233,48 +239,3 @@ def _get_asn_id(self, input_models):
233239
self._make_output_path = partial(_make_output_path, asn_id=asn_id)
234240
log.info(f"Outlier Detection asn_id: {asn_id}")
235241
return
236-
237-
def _open_models(self, input_models):
238-
"""
239-
Open the input data, making a copy if necessary.
240-
241-
If the input data is a filename or path, it is opened
242-
and the open model is returned.
243-
244-
If it is a list of models, it is opened as a ModelContainer.
245-
In this case, or if the input is a simple datamodel or a
246-
ModelContainer, a deep copy of the model/container is returned,
247-
in order to avoid modifying the input models.
248-
249-
If the input is a ModelLibrary, it is simply returned, in order
250-
to avoid making unnecessary copies for performance-critical
251-
use cases.
252-
253-
Parameters
254-
----------
255-
input_models : str, list, JwstDataModel, ModelContainer, or ModelLibrary
256-
Input data to open.
257-
258-
Returns
259-
-------
260-
JwstDataModel, ModelContainer, or ModelLibrary
261-
The opened datamodel(s).
262-
"""
263-
# Check whether input contains datamodels
264-
make_copy = False
265-
if isinstance(input_models, list):
266-
is_datamodel = [isinstance(m, datamodels.JwstDataModel) for m in input_models]
267-
if any(is_datamodel):
268-
make_copy = True
269-
elif isinstance(input_models, (datamodels.JwstDataModel, ModelContainer)):
270-
make_copy = True
271-
272-
if not isinstance(input_models, (datamodels.JwstDataModel, ModelLibrary, ModelContainer)):
273-
# Input might be a filename or path.
274-
# In that case, open it.
275-
input_models = datamodels.open(input_models)
276-
if make_copy:
277-
# For regular models, make a copy to avoid modifying the input.
278-
# Leave libraries alone for memory management reasons.
279-
input_models = input_models.copy()
280-
return input_models

jwst/skymatch/skymatch_step.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,15 @@ def process(self, input_models):
7171
ModelLibrary
7272
A library of datamodels with the skymatch step applied.
7373
"""
74-
if isinstance(input_models, ModelLibrary):
75-
library = input_models
74+
# Check the input for open models and make a copy if necessary
75+
# to avoid modifying input data.
76+
# If there are no open models already, do not open them. Leave
77+
# that to the ModelLibrary call below.
78+
output_models = self.prepare_output(input_models, open_models=False)
79+
if isinstance(output_models, ModelLibrary):
80+
library = output_models
7681
else:
77-
library = ModelLibrary(input_models, on_disk=not self.in_memory)
82+
library = ModelLibrary(output_models, on_disk=not self.in_memory)
7883

7984
# Method: "user". Use user-provided sky values, and bypass skymatch() altogether.
8085
if self.skymethod == "user":

jwst/skymatch/tests/test_skymatch.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,13 @@ def test_skymatch(nircam_rate, skymethod, subtract, skystat, match_down, grouped
208208
)
209209

210210
if skymethod == "match" and grouped:
211-
# nothing to "match" when there is only one group:
212-
assert im.meta.background.method is None
213-
assert im.meta.background.subtracted is None
214-
215211
# test that output models have original sky levels on failure:
216212
with result:
217213
for im, lev in zip(result, levels):
214+
# nothing to "match" when there is only one group:
215+
assert im.meta.background.method is None
216+
assert im.meta.background.subtracted is None
217+
218218
assert abs(np.mean(im.data[dq_mask]) - lev) < 0.01
219219
result.shelve(im, modify=False)
220220

@@ -603,3 +603,37 @@ def test_user_skyfile(tmp_cwd, nircam_rate, subtract):
603603

604604
with pytest.raises(ValueError):
605605
SkyMatchStep.call(container, skymethod="user", skylist=skyfile)
606+
607+
608+
@pytest.mark.parametrize("success", [True, False])
def test_output_is_not_input(nircam_rate, success):
    """
    Check that the step leaves non-library input untouched.

    Only model/container input is covered here. A ModelLibrary input is
    deliberately not copied by the step, since avoiding extra copies is
    assumed to matter most in that case.
    """
    im1 = nircam_rate.copy()
    im2, im3 = im1.copy(), im1.copy()
    if success:
        # Distinct sequence IDs let skymatch proceed; without them
        # the step is skipped.
        for image, sequence_id in ((im2, "2"), (im3, "3")):
            image.meta.observation.sequence_id = sequence_id
    container = ModelContainer([im1, im2, im3])
    expected_status = "COMPLETE" if success else "SKIPPED"

    result = SkyMatchStep.call(container)

    with result:
        for im, input_im in zip(result, container):
            # The output carries the step status ...
            assert im.meta.cal_step.skymatch == expected_status
            # ... while the input is a distinct, unmodified object.
            assert im is not input_im
            assert input_im.meta.cal_step.skymatch is None

            result.shelve(im, modify=False)

jwst/stpipe/core.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,92 @@ def load_as_level3_asn(self, obj):
145145
update_key_value(asn, "expname", (), mod_func=self.make_input_path)
146146
return asn
147147

148+
def prepare_output(self, init, make_copy=None, open_models=True, **kwargs):
    """
    Open the input data as a model, making a copy if necessary.

    If the input is not already a JwstDataModel, ModelContainer or
    ModelLibrary (e.g. it is a filename or path), it is opened with
    ``datamodels.open`` when ``open_models`` is True, and otherwise
    passed through unchanged.

    If the input is a list containing at least one datamodel, it is
    wrapped in a ModelContainer. In this case, or if the input is a
    single datamodel or a ModelContainer, a copy of the
    model/container is returned by default, in order to avoid
    modifying the input models.

    If the input is a ModelLibrary, it is returned as is by default,
    in order to avoid making unnecessary copies for
    performance-critical use cases.

    All default copies are skipped if this step has a parent (i.e. it
    is called as part of a pipeline).

    Set make_copy explicitly to True or False to override the above
    behavior.

    Parameters
    ----------
    init : str, list, JwstDataModel, ModelContainer, or ModelLibrary
        Input data to open.
    make_copy : bool or None
        If True, a copy of the input will always be made.
        If False, a copy will never be made. If None, a copy is
        conditionally made, depending on the input and whether the
        step is called in a standalone context.
    open_models : bool
        If True and the input is not already a datamodel, container
        or library, then datamodels.open will be called to open the
        input. If False, such input is returned as is.
    **kwargs
        Additional keyword arguments to pass to datamodels.open. Used
        only when datamodels.open is actually called.

    Returns
    -------
    JwstDataModel, ModelContainer, or ModelLibrary
        The opened datamodel(s). If ``open_models`` is False and the
        input was not a model type, the input object itself.

    Raises
    ------
    TypeError
        If a copy is requested and the object to copy does not provide
        a callable ``copy`` method (e.g. an unopened filename string).
    """
    model_types = (datamodels.JwstDataModel, ModelLibrary, ModelContainer)

    # Decide whether this kind of input should be copied by default.
    copy_needed = False
    if isinstance(init, list):
        if any(isinstance(m, datamodels.JwstDataModel) for m in init):
            # Make the list into a ModelContainer, since it contains models
            init = ModelContainer(init)
            copy_needed = True
    elif isinstance(init, (datamodels.JwstDataModel, ModelContainer)):
        copy_needed = True

    # Anything that is not already a model type might be a filename or
    # path: open it if requested, otherwise use the input directly.
    if open_models and not isinstance(init, model_types):
        input_models = datamodels.open(init, **kwargs)
    else:
        input_models = init

    # Default copy policy: only copy when running standalone.
    if make_copy is None:
        make_copy = copy_needed and self.parent is None

    if make_copy:
        # Look the method up first rather than wrapping the call in
        # ``except AttributeError``: an AttributeError raised *inside*
        # a copy must propagate with its own traceback instead of being
        # misreported as an uncopyable input type.
        # NOTE(review): the original comment said this is only expected
        # when make_copy is explicitly True and the input is a string or
        # a ModelLibrary -- confirm ModelLibrary has no ``copy`` method.
        copy_method = getattr(input_models, "copy", None)
        if not callable(copy_method):
            raise TypeError(f"Copy is not possible for input type {type(input_models)}")
        input_models = copy_method()

    return input_models
233+
148234
def finalize_result(self, result, reference_files_used):
149235
"""
150236
Update the result with the software version and reference files used.

0 commit comments

Comments
 (0)