Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/9728.wfss_contam.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a bug where incorrect pixel locations were fed into the detector-to-grism transform, leading to incorrectly-shaped model traces
16 changes: 9 additions & 7 deletions jwst/extract_2d/grisms.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,6 @@ def extract_grism_objects(
order_model = Const1D(order)
order_model.inverse = Const1D(order)

tr = inwcs.get_transform("grism_detector", "detector")
tr = (
Mapping((0, 1, 0, 0, 0))
| (Shift(xmin) & Shift(ymin) & xcenter_model & ycenter_model & order_model)
| tr
)
y_slice = slice(_toindex(ymin), _toindex(ymax) + 1)
x_slice = slice(_toindex(xmin), _toindex(xmax) + 1)

Expand All @@ -499,10 +493,18 @@ def extract_grism_objects(
else:
var_flat = None

# add a new transform to the WCS that shifts to the center of the virtual slit
tr = Mapping((0, 1, 0, 0, 0)) | (
Shift(xmin) & Shift(ymin) & xcenter_model & ycenter_model & order_model
)
bind_bounding_box(
tr, util.transform_bbox_from_shape(ext_data.shape, order="F"), order="F"
)
subwcs.set_transform("grism_detector", "detector", tr)
grism_slit = copy.deepcopy(subwcs.grism_detector)
grism_slit.name = "grism_slit"
subwcs.insert_frame(
input_frame=grism_slit, output_frame="grism_detector", transform=tr
)

new_slit = datamodels.SlitModel(
data=ext_data,
Expand Down
39 changes: 12 additions & 27 deletions jwst/wfss_contam/disperse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings

import numpy as np
from astropy.modeling.mappings import Mapping
from scipy import sparse

from jwst.lib.winclip import get_clipped_pixels
Expand All @@ -23,8 +24,6 @@ def _determine_native_wl_spacing(
wmin,
wmax,
oversample_factor=2,
xoffset=0,
yoffset=0,
):
"""
Determine the wavelength spacing necessary to adequately sample the dispersed frame.
Expand All @@ -47,10 +46,6 @@ def _determine_native_wl_spacing(
Maximum wavelength for dispersed spectra
oversample_factor : int, optional
Factor by which to oversample the wavelength grid
xoffset : int, optional
X offset to apply to the dispersed pixel positions
yoffset : int, optional
Y offset to apply to the dispersed pixel positions

Returns
-------
Expand All @@ -68,8 +63,8 @@ def _determine_native_wl_spacing(
# Convert to x/y in the direct image frame
x0_xy, y0_xy, _, _ = sky_to_imgxy(x0_sky, y0_sky, 1, order)
# then convert to x/y in the grism image frame.
xwmin, ywmin = imgxy_to_grismxy(x0_xy + xoffset, y0_xy + yoffset, wmin, order)
xwmax, ywmax = imgxy_to_grismxy(x0_xy + xoffset, y0_xy + yoffset, wmax, order)
xwmin, ywmin = imgxy_to_grismxy(x0_xy, y0_xy, wmin, order)
xwmax, ywmax = imgxy_to_grismxy(x0_xy, y0_xy, wmax, order)
dxw = xwmax - xwmin
dyw = ywmax - ywmin

Expand All @@ -80,9 +75,7 @@ def _determine_native_wl_spacing(
return lambdas


def _disperse_onto_grism(
x0_sky, y0_sky, sky_to_imgxy, imgxy_to_grismxy, lambdas, order, xoffset=0, yoffset=0
):
def _disperse_onto_grism(x0_sky, y0_sky, sky_to_imgxy, imgxy_to_grismxy, lambdas, order):
"""
Compute x/y positions in the grism image for the set of desired wavelengths.

Expand All @@ -100,10 +93,6 @@ def _disperse_onto_grism(
Wavelengths at which to compute dispersed pixel values
order : int
Spectral order number
xoffset : int, optional
X offset to apply to the dispersed pixel positions
yoffset : int, optional
Y offset to apply to the dispersed pixel positions

Returns
-------
Expand All @@ -123,7 +112,7 @@ def _disperse_onto_grism(

# Convert to x/y in grism frame.
lambdas = np.repeat(lambdas[:, np.newaxis], x0_xy.shape[1], axis=1)
x0s, y0s = imgxy_to_grismxy(x0_xy + xoffset, y0_xy + yoffset, lambdas, order)
x0s, y0s = imgxy_to_grismxy(x0_xy, y0_xy, lambdas, order)
# x0s, y0s now have shape (n_lam, n_pixels)
return x0s, y0s, lambdas

Expand Down Expand Up @@ -210,8 +199,6 @@ def disperse(
grism_wcs,
naxis,
oversample_factor=2,
xoffset=0,
yoffset=0,
):
"""
Compute the dispersed image pixel values from the direct image.
Expand Down Expand Up @@ -244,10 +231,6 @@ def disperse(
Dimensions of the grism image (naxis[0], naxis[1])
oversample_factor : int, optional
Factor by which to oversample the wavelength grid
xoffset : float, optional
X offset to apply to the dispersed pixel positions
yoffset : float, optional
Y offset to apply to the dispersed pixel positions

Returns
-------
Expand All @@ -269,8 +252,14 @@ def disperse(
sky_to_imgxy = grism_wcs.get_transform("world", "detector")
imgxy_to_grismxy = grism_wcs.get_transform("detector", "grism_detector")

# We only need the x,y outputs of imgxy_to_grismxy
# Making the number of outputs dynamic handles legacy WCS objects that did not pass
# the x0, y0, and order through the transform unmodified like the current version does.
n_outputs = len(imgxy_to_grismxy.outputs)
imgxy_to_grismxy = imgxy_to_grismxy | Mapping((0, 1), n_inputs=n_outputs)

# Find RA/Dec of the input pixel position in segmentation map
x0_sky, y0_sky = seg_wcs(x0, y0)
x0_sky, y0_sky = seg_wcs(x0, y0, with_bounding_box=False)

# native spacing does not change much over the detector, so just put in one x0, y0
lambdas = _determine_native_wl_spacing(
Expand All @@ -282,8 +271,6 @@ def disperse(
wmin,
wmax,
oversample_factor=oversample_factor,
xoffset=xoffset,
yoffset=yoffset,
)
nlam = len(lambdas)

Expand All @@ -294,8 +281,6 @@ def disperse(
imgxy_to_grismxy,
lambdas,
order,
xoffset=xoffset,
yoffset=yoffset,
)

# If none of the dispersed pixel indexes are within the image frame,
Expand Down
9 changes: 0 additions & 9 deletions jwst/wfss_contam/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def __init__(
filter_name,
source_id=None,
boundaries=None,
offsets=None,
max_cpu=1,
max_pixels_per_chunk=5e4,
oversample_factor=2,
Expand All @@ -139,8 +138,6 @@ def __init__(
ID of source to process. If 0, all sources processed.
boundaries : list, optional, default []
Start/Stop coordinates of the FOV within the larger seed image.
offsets : list, optional, default [0,0]
Offset values for x and y axes
max_cpu : int, optional, default 1
Max number of cpu's to use when multiprocessing
max_pixels_per_chunk : int, optional, default 1e5
Expand All @@ -150,8 +147,6 @@ def __init__(
"""
if boundaries is None:
boundaries = []
if offsets is None:
offsets = [0, 0]
# Load all the info for this grism mode
self.seg_wcs = segmap_model.meta.wcs
self.grism_wcs = grism_wcs
Expand All @@ -164,8 +159,6 @@ def __init__(
self.max_cpu = max_cpu
self.max_pixels_per_chunk = max_pixels_per_chunk
self.oversample_factor = oversample_factor
self.xoffset = offsets[0]
self.yoffset = offsets[1]

# ensure the direct image has background subtracted
self.dimage = background_subtract(direct_image)
Expand Down Expand Up @@ -271,8 +264,6 @@ def chunk_sources(self, order, wmin, wmax, sens_waves, sens_response, max_pixels
self.grism_wcs,
self.naxis,
self.oversample_factor,
self.xoffset,
self.yoffset,
]
)

Expand Down
30 changes: 4 additions & 26 deletions jwst/wfss_contam/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import warnings

import asdf
import numpy as np
import pytest
import stdatamodels.jwst.datamodels as dm
from astropy.convolution import convolve
from astropy.stats import sigma_clipped_stats
from astropy.table import Table
from astropy.utils.data import get_pkg_data_filename
from photutils.datasets import make_100gaussians_image
from photutils.segmentation import SourceFinder, make_2dgaussian_kernel

from jwst.assign_wcs.tests.test_niriss import create_imaging_wcs, create_wfss_wcs

DIR_IMAGE = "direct_image.fits"


Expand Down Expand Up @@ -69,18 +67,7 @@ def segmentation_map(direct_image):

# turn this into a jwst datamodel
model = dm.SegmentationMapModel(data=segm.data)
with warnings.catch_warnings():
# asdf.exceptions.AsdfPackageVersionWarning in oldestdeps job
warnings.filterwarnings(
"ignore",
message="File .* was created with extension URI .* which is not currently installed",
)
with asdf.open(
get_pkg_data_filename("data/segmentation_wcs.asdf", package="jwst.wfss_contam.tests")
) as asdf_file:
wcsobj = asdf_file.tree["wcs"]
model.meta.wcs = wcsobj

model.meta.wcs = create_imaging_wcs("F200W")
return model


Expand Down Expand Up @@ -121,13 +108,4 @@ def grism_wcs():
-----
This should probably be mocked in future updates.
"""
with warnings.catch_warnings():
# asdf.exceptions.AsdfPackageVersionWarning in oldestdeps job
warnings.filterwarnings(
"ignore",
message="File .* was created with extension URI .* which is not currently installed",
)
with asdf.open(
get_pkg_data_filename("data/grism_wcs.asdf", package="jwst.wfss_contam.tests")
) as asdf_file:
return asdf_file.tree["wcs"]
return create_wfss_wcs("GR150C", pupil="F200W")
Loading
Loading