Skip to content
Merged
1 change: 1 addition & 0 deletions changes/9442.ami.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve memory usage of AmiAnalyzeStep
88 changes: 39 additions & 49 deletions jwst/ami/nrm_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,38 +82,16 @@ def fit_fringes_all(self, input_model):
"""
# scidata, dqmask are already centered around peak
self.scidata, self.dqmask = self.instrument_data.read_data_model(input_model)
self.instrument_data.isz = self.scidata.shape[1]

# list for nrm objects for each slc
self.nrm_list = []
# Initialize the output oifts models
oifits_model = oifits.RawOifits(self.instrument_data)
oifits_model.initialize_obsarrays()
oifits_model_multi = oifits.RawOifits(self.instrument_data, method="multi")
oifits_model_multi.initialize_obsarrays()

for slc in range(self.instrument_data.nslices):
log.info(f"Fitting fringes for iteration {slc} of {self.instrument_data.nslices}")
self.nrm_list.append(self.fit_fringes_single_integration(slc))

# Now save final output model(s) of all slices, averaged slices to AmiOiModels
# averaged
oifits_model = oifits.RawOifits(self)
output_model = oifits_model.make_oifits()

# multi-integration
oifits_model_multi = oifits.RawOifits(self, method="multi")
output_model_multi = oifits_model_multi.make_oifits()

# Save cropped/centered data, model, residual in AmiLgFitModel
lgfit = self.make_lgfitmodel()

return output_model, output_model_multi, lgfit

def make_lgfitmodel(self):
"""
Populate the LGFitModel with the output of the fringe fitting (LG algorithm).

Returns
-------
m : AmiLgFitModel object
LG analysis centered data, fit, residual, and model info
"""
nslices = len(self.nrm_list)
# initialize the output lgfitmodel product arrays
nslices = self.instrument_data.nslices
# 3d arrays of centered data, models, and residuals (data - model)
ctrd_arr = np.zeros((nslices, self.scidata.shape[1], self.scidata.shape[2]))
n_ctrd_arr = np.zeros((nslices, self.scidata.shape[1], self.scidata.shape[2]))
Expand All @@ -124,31 +102,43 @@ def make_lgfitmodel(self):
# Model parameters
solns_arr = np.zeros((nslices, 44))

for i, nrmslc in enumerate(self.nrm_list):
for slc in range(nslices):
log.info(f"Fitting fringes for iteration {slc} of {nslices}")
nrmslc = self.fit_fringes_single_integration(slc)

# populate the solutions of the lgfit model
datapeak = nrmslc.reference.max()
ctrd_arr[i, :, :] = nrmslc.reference
n_ctrd_arr[i, :, :] = nrmslc.reference / datapeak
model_arr[i, :, :] = nrmslc.modelpsf
n_model_arr[i, :, :] = nrmslc.modelpsf / datapeak
resid_arr[i, :, :] = nrmslc.residual
n_resid_arr[i, :, :] = nrmslc.residual / datapeak
solns_arr[i, :] = nrmslc.soln

# Populate datamodel
m = datamodels.AmiLgFitModel()
m.centered_image = ctrd_arr
m.norm_centered_image = n_ctrd_arr
m.fit_image = model_arr
m.norm_fit_image = n_model_arr
m.resid_image = resid_arr
m.norm_resid_image = n_resid_arr
m.solns_table = np.recarray(
ctrd_arr[slc, :, :] = nrmslc.reference
n_ctrd_arr[slc, :, :] = nrmslc.reference / datapeak
model_arr[slc, :, :] = nrmslc.modelpsf
n_model_arr[slc, :, :] = nrmslc.modelpsf / datapeak
resid_arr[slc, :, :] = nrmslc.residual
n_resid_arr[slc, :, :] = nrmslc.residual / datapeak
solns_arr[slc, :] = nrmslc.soln

# populate the oifits models
oifits_model.populate_obsarray(slc, nrmslc)
oifits_model_multi.populate_obsarray(slc, nrmslc)

# Populate the LGFitModel with the output of the fringe fitting (LG algorithm).
lgfit = datamodels.AmiLgFitModel()
lgfit.centered_image = ctrd_arr
lgfit.norm_centered_image = n_ctrd_arr
lgfit.fit_image = model_arr
lgfit.norm_fit_image = n_model_arr
lgfit.resid_image = resid_arr
lgfit.norm_resid_image = n_resid_arr
lgfit.solns_table = np.recarray(
solns_arr.shape[0],
dtype=[("coeffs", "f8", solns_arr.shape[1])],
buf=solns_arr,
)

return m
# Now save final output model(s) of all slices, averaged slices to AmiOiModels
output_model = oifits_model.make_oifits()
output_model_multi = oifits_model_multi.make_oifits()

return output_model, output_model_multi, lgfit

def fit_fringes_single_integration(self, slc):
"""
Expand Down
68 changes: 39 additions & 29 deletions jwst/ami/oifits.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ class RawOifits:
"""
Store AMI data in the format required to write out to OIFITS files.

Takes fringefitter class, which contains nrm_list and instrument_data attributes,
all info needed to write oifits. Angular quantities of input are in radians from
fringe fitting; converted to degrees for saving. Populate the structure needed to
write out oifits files according to schema.
Builds the structure needed to write out oifits files according to the schema.
Populates the structure with the observables from the fringe fitter.
For data arrays, the observables are populated slice-by-slice to enable
fitting with FringeFitter and storage as OiFits to take place in the same loop.
Angular quantities, initially in radians from fringe fitting,
are converted to degrees for saving.
Produces averaged and multi-integration versions, with sigma-clipped stats over
integrations.

Expand All @@ -30,22 +32,22 @@ class RawOifits:
https://github.com/anand0xff/ImPlaneIA/blob/master/nrm_analysis/misctools/implane2oifits.py#L32
"""

def __init__(self, fringefitter, method="mean"):
def __init__(self, instrument_data, method="mean"):
"""
Initialize the RawOifits object.

Parameters
----------
fringefitter : FringeFitter object
Object containing nrm_list attribute (list of nrm objects)
and other info needed for OIFITS files
instrument_data : jwst.ami.instrument_data.NIRISS object
Information on the mask geometry (namely # holes), instrument,
wavelength obs mode.
method : str
Method to average observables: mean or median. Default mean.
"""
self.fringe_fitter = fringefitter
self.n_holes = 7
self.instrument_data = instrument_data

self.nslices = len(self.fringe_fitter.nrm_list) # n ints
self.nslices = self.instrument_data.nslices # n ints
self.n_baselines = int(comb(self.n_holes, 2)) # 21
self.n_closure_phases = int(comb(self.n_holes, 3)) # 35
self.n_closure_amplitudes = int(comb(self.n_holes, 4)) # also 35
Expand All @@ -60,15 +62,15 @@ def __init__(self, fringefitter, method="mean"):
log.warning(msg)
self.method = "mean"

self.ctrs_eqt = self.fringe_fitter.instrument_data.ctrs_eqt
self.ctrs_inst = self.fringe_fitter.instrument_data.ctrs_inst
self.ctrs_eqt = self.instrument_data.ctrs_eqt
self.ctrs_inst = self.instrument_data.ctrs_inst

self.bholes, self.bls = self._makebaselines()
self.tholes, self.tuv = self._maketriples_all()
self.qholes, self.quads = self._makequads_all()

def make_obsarrays(self):
"""Make arrays of observables of the correct shape for saving to datamodels."""
def initialize_obsarrays(self):
"""Initialize arrays of observables to empty arrays."""
# empty arrays of observables, (nslices,nobservables) shape.
self.fringe_phases = np.zeros((self.nslices, self.n_baselines))
self.fringe_amplitudes = np.zeros((self.nslices, self.n_baselines))
Expand All @@ -77,21 +79,30 @@ def make_obsarrays(self):
self.q4_phases = np.zeros((self.nslices, self.n_closure_amplitudes))
self.closure_amplitudes = np.zeros((self.nslices, self.n_closure_amplitudes))
self.pistons = np.zeros((self.nslices, self.n_holes))
# model parameters
self.solns = np.zeros((self.nslices, 44))
self.fringe_amplitudes_squared = np.zeros((self.nslices, self.n_baselines))

def populate_obsarray(self, i, nrmslc):
"""
Populate arrays of observables with fringe fitter results.

Parameters
----------
i : int
Index of the integration
nrmslc : object
Object containing the results of the fringe fitting for this integration
"""
# populate with each integration's observables
for i, nrmslc in enumerate(self.fringe_fitter.nrm_list):
self.fringe_phases[i, :] = nrmslc.fringephase # FPs in radians
self.fringe_amplitudes[i, :] = nrmslc.fringeamp
self.closure_phases[i, :] = nrmslc.redundant_cps # CPs in radians
self.t3_amplitudes[i, :] = nrmslc.t3_amplitudes
self.q4_phases[i, :] = nrmslc.q4_phases # quad phases in radians
self.closure_amplitudes[i, :] = nrmslc.redundant_cas
self.pistons[i, :] = nrmslc.fringepistons # segment pistons in radians
self.solns[i, :] = nrmslc.soln

self.fringe_amplitudes_squared = self.fringe_amplitudes**2 # squared visibilities
self.fringe_phases[i, :] = nrmslc.fringephase # FPs in radians
self.fringe_amplitudes[i, :] = nrmslc.fringeamp
self.closure_phases[i, :] = nrmslc.redundant_cps # CPs in radians
self.t3_amplitudes[i, :] = nrmslc.t3_amplitudes
self.q4_phases[i, :] = nrmslc.q4_phases # quad phases in radians
self.closure_amplitudes[i, :] = nrmslc.redundant_cas
self.pistons[i, :] = nrmslc.fringepistons # segment pistons in radians
self.solns[i, :] = nrmslc.soln
self.fringe_amplitudes_squared[i, :] = nrmslc.fringeamp**2 # squared visibilities

def rotate_matrix(self, cov_mat, theta):
"""
Expand Down Expand Up @@ -329,8 +340,7 @@ def make_oifits(self):
m : AmiOIModel
Fully populated datamodel
"""
self.make_obsarrays()
instrument_data = self.fringe_fitter.instrument_data
instrument_data = self.instrument_data
observation_date = Time(
f"{instrument_data.year}-{instrument_data.month}-{instrument_data.day}",
format="fits",
Expand Down Expand Up @@ -445,7 +455,7 @@ def make_oifits(self):

pscale = instrument_data.pscale_mas / 1000.0 # arcsec
# Size of the image to extract NRM data
isz = self.fringe_fitter.scidata.shape[1]
isz = self.instrument_data.isz
fov = [pscale * isz] * self.n_holes
fovtype = ["RADIUS"] * self.n_holes

Expand Down