"""
Model grism spectra in individual FLTs
"""
import os
import glob
from collections import OrderedDict
import copy
import traceback
import numpy as np
import scipy.ndimage as nd
import matplotlib.pyplot as plt
import astropy.io.fits as pyfits
from astropy.table import Table
import astropy.wcs as pywcs
import astropy.units as u
#import stwcs
# Helper functions from a document written by Pirzkal, Brammer & Ryan
from . import grismconf
from . import utils
from . import GRIZLI_PATH
# Would prefer 'nearest' but that occasionally segment faults out
SEGMENTATION_INTERP = 'nearest'
# Factors for converting HST countrates to Flamba flux densities
photflam_list = {'F098M': 6.0501324882418389e-20,
'F105W': 3.038658152508547e-20,
'F110W': 1.5274130068787271e-20,
'F125W': 2.2483414275260141e-20,
'F140W': 1.4737154005353565e-20,
'F160W': 1.9275637653833683e-20,
'F435W': 3.1871480286278679e-19,
'F606W': 7.8933594352047833e-20,
'F775W': 1.0088466875014488e-19,
'F814W': 7.0767633156044843e-20,
'VISTAH': 1.9275637653833683e-20*0.95,
'GRISM': 1.e-20,
'G150': 1.e-20,
'G800L': 1.,
'G280': 1.,
'F444W': 1.e-20,
'F115W': 1.,
'F150W': 1.,
'F200W': 1.}
# Filter pivot wavelengths
photplam_list = {'F098M': 9864.722728110915,
'F105W': 10551.046906405772,
'F110W': 11534.45855553774,
'F125W': 12486.059785775655,
'F140W': 13922.907350356367,
'F160W': 15369.175708965562,
'F435W': 4328.256914042873,
'F606W': 5921.658489236346,
'F775W': 7693.297933335407,
'F814W': 8058.784799323767,
'VISTAH': 1.6433e+04,
'GRISM': 1.6e4, # WFIRST/Roman
'G150': 1.46e4, # WFIRST/Roman
'G800L': 7.4737026e3,
'G280': 3651.,
'F070W': 7.043e+03, # NIRCam
'F090W': 9.023e+03,
'F115W': 1.150e+04, # NIRISS
'F150W': 1.493e+04, # NIRISS
'F200W': 1.993e+04, # NIRISS
'F150W2': 1.658e+04,
'F140M': 1.405e+04,
'F158M': 1.582e+04, # NIRISS
'F162M': 1.627e+04,
'F182M': 1.845e+04,
'F210M': 2.096e+04,
'F164N': 1.645e+04,
'F187N': 1.874e+04,
'F212N': 2.121e+04,
'F277W': 2.758e+04,
'F356W': 3.568e+04,
'F444W': 4.404e+04,
'F322W2': 3.232e+04,
'F250M': 2.503e+04,
'F300M': 2.987e+04,
'F335M': 3.362e+04,
'F360M': 3.624e+04,
'F380M': 3.825e+04, # NIRISS
'F410M': 4.082e+04,
'F430M': 4.280e+04,
'F460M': 4.626e+04,
'F480M': 4.816e+04,
'F323N': 3.237e+04,
'F405N': 4.052e+04,
'F466N': 4.654e+04,
'F470N': 4.708e+04}
# character to skip clearing line on STDOUT printing
#no_newline = '\x1b[1A\x1b[1M'
# Demo for computing photflam and photplam with pysynphot
if False:
import pysynphot as S
n = 1.e-20
spec = S.FlatSpectrum(n, fluxunits='flam')
photflam_list = {}
photplam_list = {}
for filter in ['F098M', 'F105W', 'F110W', 'F125W', 'F140W', 'F160W', 'G102', 'G141']:
bp = S.ObsBandpass('wfc3,ir,{0}'.format(filter.lower()))
photplam_list[filter] = bp.pivot()
obs = S.Observation(spec, bp)
photflam_list[filter] = n/obs.countrate()
for filter in ['F435W', 'F606W', 'F775W', 'F814W']:
bp = S.ObsBandpass('acs,wfc1,{0}'.format(filter.lower()))
photplam_list[filter] = bp.pivot()
obs = S.Observation(spec, bp)
photflam_list[filter] = n/obs.countrate()
[docs]class GrismDisperser(object):
def __init__(self, id=0, direct=None,
segmentation=None, origin=[500, 500],
xcenter=0., ycenter=0., pad=(0,0), grow=1, beam='A',
conf=['WFC3', 'F140W', 'G141'], scale=1.,
fwcpos=None, MW_EBV=0., yoffset=0, xoffset=None):
"""Object for computing dispersed model spectra
Parameters
----------
id : int
Only consider pixels in the segmentation image with value `id`.
Default of zero to match the default empty segmentation image.
direct : `~numpy.ndarray`
Direct image cutout in f_lambda units (i.e., e-/s times PHOTFLAM).
Default is a trivial zeros array.
segmentation : `~numpy.ndarray` (float32) or None
Segmentation image. If None, create a zeros array with the same
shape as `direct`.
origin : [int, int]
`origin` defines the lower left pixel index (y,x) of the `direct`
cutout from a larger detector-frame image
xcenter, ycenter : float, float
Sub-pixel centering of the exact center of the object, relative
to the center of the thumbnail. Needed for getting exact
wavelength grid correct for the extracted 2D spectra.
pad : int, int
Offset between origin = [0,0] and the true lower left pixel of the
detector frame. This can be nonzero for cases where one creates
a direct image that extends beyond the boundaries of the nominal
detector frame to model spectra at the edges.
grow : int >= 1
Interlacing factor.
beam : str
Spectral order to compute. Must be defined in `self.conf.beams`
conf : [str, str, str] or `grismconf.aXeConf` object.
Pre-loaded aXe-format configuration file object or if list of
strings determine the appropriate configuration filename with
`grismconf.get_config_filename` and load it.
scale : float
Multiplicative factor to apply to the modeled spectrum from
`compute_model`.
fwcpos : float
Rotation position of the NIRISS filter wheel
MW_EBV : float
Galactic extinction
yoffset : float
Cross-dispersion offset to apply to the trace
xoffset : float
Dispersion offset to apply to the trace
Attributes
----------
sh : 2-tuple
shape of the direct array
sh_beam : 2-tuple
computed shape of the 2D spectrum
seg : `~numpy.array`
segmentation array
lam : `~numpy.array`
wavelength along the trace
ytrace : `~numpy.array`
y pixel center of the trace. Has same dimensions as sh_beam[1].
sensitivity : `~numpy.array`
conversion factor from native e/s to f_lambda flux densities
lam_beam, ytrace_beam, sensitivity_beam : `~numpy.array`
Versions of the above attributes defined for just the specific
pixels of the pixel beam, not the full 2D extraction.
modelf, model : `~numpy.array`, `~numpy.ndarray`
2D model spectrum. `model` is linked to `modelf` with "reshape",
the later which is a flattened 1D array where the fast
calculations are actually performed.
model : `~numpy.ndarray`
2D model spectrum linked to `modelf` with reshape.
slx_parent, sly_parent : slice
slices defined relative to `origin` to match the location of the
computed 2D spectrum.
total_flux : float
Total f_lambda flux in the thumbail within the segmentation
region.
"""
self.id = id
# lower left pixel of the `direct` array in native detector
# coordinates
self.origin = origin
if isinstance(pad, int):
self.pad = [pad, pad]
else:
self.pad = pad
self.grow = grow
# Galactic extinction
self.MW_EBV = MW_EBV
self.init_galactic_extinction(self.MW_EBV)
self.fwcpos = fwcpos
self.scale = scale
# Direct image
if direct is None:
direct = np.zeros((20, 20), dtype=np.float32)
self.direct = direct
self.sh = self.direct.shape
if self.direct.dtype is not np.float32:
self.direct = np.asarray(self.direct,dtype=np.float32)
# Segmentation image, defaults to all zeros
if segmentation is None:
#self.seg = np.zeros_like(self.direct, dtype=np.float32)
empty = np.zeros_like(self.direct, dtype=np.float32)
self.set_segmentation(empty)
else:
self.set_segmentation(segmentation.astype(np.float32))
# Initialize attributes
self.spectrum_1d = None
self.is_cgs = False
self.xc = self.sh[1]/2+self.origin[1]
self.yc = self.sh[0]/2+self.origin[0]
# Sub-pixel centering of the exact center of the object, relative
# to the center of the thumbnail
self.xcenter = xcenter
self.ycenter = ycenter
self.beam = beam
# Config file
if isinstance(conf, list):
conf_f = grismconf.get_config_filename(*conf)
self.conf = grismconf.load_grism_config(conf_f)
else:
self.conf = conf
# Get Pixel area map (xxx need to add test for WFC3)
self.PAM_value = self.get_PAM_value(verbose=False)
self.process_config()
self.yoffset = yoffset
if xoffset is not None:
self.xoffset = xoffset
if (yoffset != 0) | (xoffset is not None):
#print('yoffset!', yoffset)
self.add_ytrace_offset(yoffset)
[docs] def set_segmentation(self, seg_array):
"""
Set Segmentation array and `total_flux`.
"""
self.seg = seg_array*1
self.seg_ids = list(np.unique(self.seg))
try:
self.total_flux = self.direct[self.seg == self.id].sum()
if self.total_flux == 0:
self.total_flux = 1
except:
self.total_flux = 1.
[docs] def init_galactic_extinction(self, MW_EBV=0., R_V=utils.MW_RV):
"""
Initialize Fitzpatrick 99 Galactic extinction
Parameters
----------
MW_EBV : float
Local E(B-V)
R_V : float
Relation between specific and total extinction,
``a_v = r_v * ebv``.
Returns
-------
Sets `self.MW_F99` attribute, which is a callable function that
returns the extinction for a supplied array of wavelengths.
If MW_EBV <= 0, then sets `self.MW_F99 = None`.
"""
self.MW_F99 = None
if MW_EBV > 0:
self.MW_F99 = utils.MW_F99(MW_EBV*R_V, r_v=R_V)
[docs] def process_config(self):
"""Process grism config file
Parameters
----------
none
Returns
-------
Sets attributes that define how the dispersion is computed. See the
attributes list for `~grizli.model.GrismDisperser`.
"""
from .utils_numba import interp
# Get dispersion parameters at the reference position
self.dx = self.conf.dxlam[self.beam] # + xcenter #-xoffset
if self.grow > 1:
self.dx = np.arange(self.dx[0]*self.grow, self.dx[-1]*self.grow)
xoffset = 0.
if ('G14' in self.conf.conf_file) & (self.beam == 'A'):
xoffset = -0.5 # necessary for WFC3/IR G141, v4.32
# xoffset = 0. # suggested by ACS
# xoffset = -2.5 # test
self.xoffset = xoffset
self.ytrace_beam, self.lam_beam = self.conf.get_beam_trace(
x=(self.xc+self.xcenter-self.pad[1])/self.grow,
y=(self.yc+self.ycenter-self.pad[0])/self.grow,
dx=(self.dx+self.xcenter*0+self.xoffset)/self.grow,
beam=self.beam, fwcpos=self.fwcpos)
self.ytrace_beam *= self.grow
# Integer trace
# Add/subtract 20 for handling int of small negative numbers
dyc = np.asarray(self.ytrace_beam+20,dtype=int)-20+1
# Account for pixel centering of the trace
self.yfrac_beam = self.ytrace_beam - np.floor(self.ytrace_beam)
# Interpolate the sensitivity curve on the wavelength grid.
ysens = self.lam_beam*0
so = np.argsort(self.lam_beam)
conf_sens = self.conf.sens[self.beam]
if self.MW_F99 is not None:
MWext = 10**(-0.4*(self.MW_F99(conf_sens['WAVELENGTH']*u.AA)))
else:
MWext = 1.
ysens[so] = interp.interp_conserve_c(self.lam_beam[so],
conf_sens['WAVELENGTH'],
conf_sens['SENSITIVITY']*MWext,
integrate=1, left=0, right=0)
self.lam_sort = so
# Needs term of delta wavelength per pixel for flux densities
# dl = np.abs(np.append(self.lam_beam[1] - self.lam_beam[0],
# np.diff(self.lam_beam)))
# ysens *= dl#*1.e-17
self.sensitivity_beam = ysens
# Initialize the model arrays
self.NX = len(self.dx)
self.sh_beam = (self.sh[0], self.sh[1]+self.NX)
self.modelf = np.zeros(np.prod(self.sh_beam), dtype=np.float32)
self.model = self.modelf.reshape(self.sh_beam)
self.idx = np.arange(self.modelf.size,
dtype=np.int64).reshape(self.sh_beam)
# Indices of the trace in the flattened array
self.x0 = np.array(self.sh, dtype=np.int64) // 2
self.x0 -= 1 # zero index!
self.dxpix = self.dx - self.dx[0] + self.x0[1] # + 1
try:
self.flat_index = self.idx[dyc + self.x0[0], self.dxpix]
except IndexError:
#print('Index Error', id, dyc.dtype, self.dxpix.dtype, self.x0[0], self.xc, self.yc, self.beam, self.ytrace_beam.max(), self.ytrace_beam.min())
raise IndexError
# Trace, wavelength, sensitivity across entire 2D array
self.dxfull = np.arange(self.sh_beam[1], dtype=int)
self.dxfull += self.dx[0]-self.x0[1]
# self.ytrace, self.lam = self.conf.get_beam_trace(x=self.xc,
# y=self.yc, dx=self.dxfull, beam=self.beam)
self.ytrace, self.lam = self.conf.get_beam_trace(
x=(self.xc+self.xcenter-self.pad[1])/self.grow,
y=(self.yc+self.ycenter-self.pad[0])/self.grow,
dx=(self.dxfull+self.xcenter+xoffset)/self.grow,
beam=self.beam, fwcpos=self.fwcpos)
self.ytrace *= self.grow
ysens = self.lam*0
so = np.argsort(self.lam)
ysens[so] = interp.interp_conserve_c(self.lam[so],
conf_sens['WAVELENGTH'],
conf_sens['SENSITIVITY']*MWext,
integrate=1, left=0, right=0)
# dl = np.abs(np.append(self.lam[1] - self.lam[0],
# np.diff(self.lam)))
# ysens *= dl#*1.e-17
self.sensitivity = ysens
# Slices of the parent array based on the origin parameter
self.slx_parent = slice(self.origin[1] + self.dxfull[0] + self.x0[1],
self.origin[1] + self.dxfull[-1] + self.x0[1]+1)
self.sly_parent = slice(self.origin[0], self.origin[0] + self.sh[0])
# print 'XXX wavelength: %s %s %s' %(self.lam[-5:], self.lam_beam[-5:], dl[-5:])
[docs] def add_ytrace_offset(self, yoffset):
"""Add an offset in Y to the spectral trace
Parameters
----------
yoffset : float
Y-offset to apply
"""
from .utils_numba.interp import interp_conserve_c
self.ytrace_beam, self.lam_beam = self.conf.get_beam_trace(
x=(self.xc+self.xcenter-self.pad[1])/self.grow,
y=(self.yc+self.ycenter-self.pad[0])/self.grow,
dx=(self.dx+self.xcenter*0+self.xoffset)/self.grow,
beam=self.beam, fwcpos=self.fwcpos)
self.ytrace_beam *= self.grow
self.yoffset = yoffset
self.ytrace_beam += yoffset
# Integer trace
# Add/subtract 20 for handling int of small negative numbers
dyc = np.asarray(self.ytrace_beam+20,dtype=int)-20+1
# Account for pixel centering of the trace
self.yfrac_beam = self.ytrace_beam - np.floor(self.ytrace_beam)
try:
self.flat_index = self.idx[dyc + self.x0[0], self.dxpix]
except IndexError:
# print 'Index Error', id, self.x0[0], self.xc, self.yc, self.beam, self.ytrace_beam.max(), self.ytrace_beam.min()
raise IndexError
# Trace, wavelength, sensitivity across entire 2D array
self.ytrace, self.lam = self.conf.get_beam_trace(
x=(self.xc+self.xcenter-self.pad[1])/self.grow,
y=(self.yc+self.ycenter-self.pad[0])/self.grow,
dx=(self.dxfull+self.xcenter+self.xoffset)/self.grow,
beam=self.beam, fwcpos=self.fwcpos)
self.ytrace *= self.grow
self.ytrace += yoffset
# Reset sensitivity
ysens = self.lam_beam*0
so = np.argsort(self.lam_beam)
conf_sens = self.conf.sens[self.beam]
if self.MW_F99 is not None:
MWext = 10**(-0.4*(self.MW_F99(conf_sens['WAVELENGTH']*u.AA)))
else:
MWext = 1.
ysens[so] = interp_conserve_c(self.lam_beam[so],
conf_sens['WAVELENGTH'],
conf_sens['SENSITIVITY']*MWext,
integrate=1, left=0, right=0)
self.lam_sort = so
self.sensitivity_beam = ysens
# Full array
ysens = self.lam*0
so = np.argsort(self.lam)
ysens[so] = interp_conserve_c(self.lam[so],
conf_sens['WAVELENGTH'],
conf_sens['SENSITIVITY']*MWext,
integrate=1, left=0, right=0)
self.sensitivity = ysens
[docs] def compute_model(self, id=None, thumb=None, spectrum_1d=None,
in_place=True, modelf=None, scale=None, is_cgs=False,
apply_sensitivity=True, reset=True):
"""Compute a model 2D grism spectrum
Parameters
----------
id : int
Only consider pixels in the segmentation image (`self.seg`) with
values equal to `id`.
thumb : `~numpy.ndarray` with shape = `self.sh` or None
Optional direct image. If `None` then use `self.direct`.
spectrum_1d : [`~numpy.array`, `~numpy.array`] or None
Optional 1D template [wave, flux] to use for the 2D grism model.
If `None`, then implicitly assumes flat f_lambda spectrum.
in_place : bool
If True, put the 2D model in `self.model` and `self.modelf`,
otherwise put the output in a clean array or preformed `modelf`.
modelf : `~numpy.array` with shape = `self.sh_beam`
Preformed (flat) array to which the 2D model is added, if
`in_place` is False.
scale : float or None
Multiplicative factor to apply to the modeled spectrum.
is_cgs : bool
Units of `spectrum_1d` fluxes are f_lambda cgs.
Returns
-------
model : `~numpy.ndarray`
If `in_place` is False, returns the 2D model spectrum. Otherwise
the result is stored in `self.model` and `self.modelf`.
"""
from .utils_numba import disperse
from .utils_numba import interp
if id is None:
id = self.id
total_flux = self.total_flux
else:
self.id = id
total_flux = self.direct[self.seg == id].sum()
# Template (1D) spectrum interpolated onto the wavelength grid
if in_place:
self.spectrum_1d = spectrum_1d
if scale is None:
scale = self.scale
else:
self.scale = scale
if spectrum_1d is not None:
xspec, yspec = spectrum_1d
scale_spec = np.zeros_like(self.sensitivity_beam)
int_func = interp.interp_conserve_c
scale_spec[self.lam_sort] = int_func(self.lam_beam[self.lam_sort],
xspec, yspec)*scale
else:
scale_spec = scale
self.is_cgs = is_cgs
if is_cgs:
scale_spec /= total_flux
# Output data, fastest is to compute in place but doesn't zero-out
# previous result
if in_place:
self.modelf *= (1-reset)
modelf = self.modelf
else:
if modelf is None:
modelf = self.modelf*(1-reset)
# Optionally use a different direct image
if thumb is None:
thumb = self.direct
else:
if thumb.shape != self.sh:
print("""
Error: `thumb` must have the same dimensions as the direct image! ({0:d},{1:d})
""".format(self.sh[0], self.sh[1]))
return False
# Now compute the dispersed spectrum using the C helper
if apply_sensitivity:
sens_curve = self.sensitivity_beam
else:
sens_curve = 1.
nonz = (sens_curve*scale_spec) != 0
if (nonz.sum() > 0) & (id in self.seg_ids):
status = disperse.disperse_grism_object(thumb, self.seg,
np.float32(id),
self.flat_index[nonz],
self.yfrac_beam[nonz].astype(np.float64),
(sens_curve*scale_spec)[nonz].astype(np.float64),
modelf,
self.x0,
np.array(self.sh, dtype=np.int64),
self.x0,
np.array(self.sh_beam, dtype=np.int64))
#print('yyy PAM')
modelf /= self.PAM_value # = self.get_PAM_value()
if not in_place:
return modelf
else:
self.model = modelf.reshape(self.sh_beam)
return True
[docs] def init_optimal_profile(self, seg_ids=None):
"""Initilize optimal extraction profile
"""
if seg_ids is None:
ids = [self.id]
else:
ids = seg_ids
for i, id in enumerate(ids):
if hasattr(self, 'psf_params'):
m_i = self.compute_model_psf(id=id, in_place=False)
else:
m_i = self.compute_model(id=id, in_place=False)
#print('Add {0} to optimal profile'.format(id))
if i == 0:
m = m_i
else:
m += m_i
m = m.reshape(self.sh_beam)
m[m < 0] = 0
self.optimal_profile = m/m.sum(axis=0)
[docs] def contained_in_full_array(self, full_array):
"""Check if subimage slice is fully contained within larger array
"""
sh = full_array.shape
if (self.sly_parent.start < 0) | (self.slx_parent.start < 0):
return False
if (self.sly_parent.stop >= sh[0]) | (self.slx_parent.stop >= sh[1]):
return False
return True
[docs] def add_to_full_image(self, data, full_array):
"""Add spectrum cutout back to the full array
`data` is *added* to `full_array` in place, so, for example, to
subtract `self.model` from the full array, call the function with
>>> self.add_to_full_image(-self.model, full_array)
Parameters
----------
data : `~numpy.ndarray` shape `self.sh_beam` (e.g., `self.model`)
Spectrum cutout
full_array : `~numpy.ndarray`
Full detector array, where the lower left pixel of `data` is given
by `origin`.
"""
if self.contained_in_full_array(full_array):
full_array[self.sly_parent, self.slx_parent] += data
else:
sh = full_array.shape
xpix = np.arange(self.sh_beam[1])
xpix += self.origin[1] + self.dxfull[0] + self.x0[1]
ypix = np.arange(self.sh_beam[0])
ypix += self.origin[0]
okx = (xpix >= 0) & (xpix < sh[1])
oky = (ypix >= 0) & (ypix < sh[1])
if (okx.sum() == 0) | (oky.sum() == 0):
return False
sly = slice(ypix[oky].min(), ypix[oky].max()+1)
slx = slice(xpix[okx].min(), xpix[okx].max()+1)
full_array[sly, slx] += data[oky, :][:, okx]
# print sly, self.sly_parent, slx, self.slx_parent
return True
[docs] def cutout_from_full_image(self, full_array):
"""Get beam-sized cutout from a full image
Parameters
----------
full_array : `~numpy.ndarray`
Array of the size of the parent array from which the cutout was
extracted. If possible, the function first tries the slices with
>>> sub = full_array[self.sly_parent, self.slx_parent]
and then computes smaller slices for cases where the beam spectrum
falls off the edge of the parent array.
Returns
-------
cutout : `~numpy.ndarray`
Array with dimensions of `self.model`.
"""
# print self.sly_parent, self.slx_parent, full_array.shape
if self.contained_in_full_array(full_array):
data = full_array[self.sly_parent, self.slx_parent]
else:
sh = full_array.shape
###
xpix = np.arange(self.sh_beam[1])
xpix += self.origin[1] + self.dxfull[0] + self.x0[1]
ypix = np.arange(self.sh_beam[0])
ypix += self.origin[0]
okx = (xpix >= 0) & (xpix < sh[1])
oky = (ypix >= 0) & (ypix < sh[1])
if (okx.sum() == 0) | (oky.sum() == 0):
return False
sly = slice(ypix[oky].min(), ypix[oky].max()+1)
slx = slice(xpix[okx].min(), xpix[okx].max()+1)
data = self.model*0.
data[oky, :][:, okx] += full_array[sly, slx]
return data
[docs] def twod_axis_labels(self, wscale=1.e4, limits=None, mpl_axis=None):
"""Set 2D wavelength (x) axis labels based on spectral parameters
Parameters
----------
wscale : float
Scale factor to divide from the wavelength units. The default
value of 1.e4 results in wavelength ticks in microns.
limits : None, list = `[x0, x1, dx]`
Will automatically use the whole wavelength range defined by the
spectrum. To change, specify `limits = [x0, x1, dx]` to
interpolate `self.beam.lam_beam` between x0*wscale and x1*wscale.
mpl_axis : `matplotlib.axes._axes.Axes`
Plotting axis to place the labels, e.g.,
>>> fig = plt.figure()
>>> mpl_axis = fig.add_subplot(111)
Returns
-------
Nothing if `mpl_axis` is supplied, else pixels and wavelengths of the
tick marks.
"""
xarr = np.arange(len(self.lam))
if limits:
xlam = np.arange(limits[0], limits[1], limits[2])
xpix = np.interp(xlam, self.lam/wscale, xarr)
else:
xlam = np.unique(np.asarray(self.lam / 1.e4*10,dtype=int)/10.)
xpix = np.interp(xlam, self.lam/wscale, xarr)
if mpl_axis is None:
return xpix, xlam
else:
mpl_axis.set_xticks(xpix)
mpl_axis.set_xticklabels(xlam)
[docs] def twod_xlim(self, x0, x1=None, wscale=1.e4, mpl_axis=None):
"""Set wavelength (x) axis limits on a 2D spectrum
Parameters
----------
x0 : float or list/tuple of floats
minimum or (min,max) of the plot limits
x1 : float or None
max of the plot limits if x0 is a float
wscale : float
Scale factor to divide from the wavelength units. The default
value of 1.e4 results in wavelength ticks in microns.
mpl_axis : `matplotlib.axes._axes.Axes`
Plotting axis to place the labels.
Returns
-------
Nothing if `mpl_axis` is supplied else pixels the desired wavelength
limits.
"""
if isinstance(x0, list) | isinstance(x0, tuple):
x0, x1 = x0[0], x0[1]
xarr = np.arange(len(self.lam))
xpix = np.interp([x0, x1], self.lam/wscale, xarr)
if mpl_axis:
mpl_axis.set_xlim(xpix)
else:
return xpix
[docs] def x_init_epsf(self, flat_sensitivity=False, psf_params=None, psf_filter='F140W', yoff=0.0, skip=0.5, get_extended=False, seg_mask=True):
"""Initialize ePSF fitting for point sources
TBD
"""
import scipy.sparse
import scipy.ndimage
#print('SKIP: {0}'.format(skip))
EPSF = utils.EffectivePSF()
if psf_params is None:
self.psf_params = [self.total_flux, 0., 0.]
else:
self.psf_params = psf_params
if self.psf_params[0] is None:
self.psf_params[0] = self.total_flux # /photflam_list[psf_filter]
origin = np.array(self.origin) - np.array(self.pad)
self.psf_yoff = yoff
self.psf_filter = psf_filter
self.psf = EPSF.get_ePSF(self.psf_params, sci=self.psf_sci,
ivar=self.psf_ivar, origin=origin,
shape=self.sh, filter=psf_filter,
get_extended=get_extended)
# self.psf_params[0] /= self.psf.sum()
# self.psf /= self.psf.sum()
# Center in detector coords
y0, x0 = np.array(self.sh)/2.-1
if len(self.psf_params) == 2:
xd = x0+self.psf_params[0] + origin[1]
yd = y0+self.psf_params[1] + origin[0]
else:
xd = x0+self.psf_params[1] + origin[1]
yd = y0+self.psf_params[2] + origin[0]
# Get wavelength array
psf_xy_lam = []
psf_ext_lam = []
for i, filter in enumerate(['F105W', 'F125W', 'F160W']):
psf_xy_lam.append(EPSF.get_at_position(x=xd, y=yd, filter=filter))
psf_ext_lam.append(EPSF.extended_epsf[filter])
filt_ix = np.arange(3)
filt_lam = np.array([1.0551, 1.2486, 1.5369])*1.e4
yp_beam, xp_beam = np.indices(self.sh_beam)
xarr = np.arange(0, self.lam_beam.shape[0], skip)
xarr = xarr[xarr <= self.lam_beam.shape[0]-1]
xbeam = np.arange(self.lam_beam.shape[0])*1.
#xbeam += 1.
# yoff = 0 #-0.15
psf_model = self.model*0.
A_psf = []
lam_psf = []
if len(self.psf_params) == 2:
lam_offset = self.psf_params[0] # self.sh[1]/2 - self.psf_params[1] - 1
else:
lam_offset = self.psf_params[1] # self.sh[1]/2 - self.psf_params[1] - 1
self.lam_offset = lam_offset
for xi in xarr:
yi = np.interp(xi, xbeam, self.ytrace_beam)
li = np.interp(xi, xbeam, self.lam_beam)
if len(self.psf_params) == 2:
dx = xp_beam-self.psf_params[0]-xi-x0
dy = yp_beam-self.psf_params[1]-yi+yoff-y0
else:
dx = xp_beam-self.psf_params[1]-xi-x0
dy = yp_beam-self.psf_params[2]-yi+yoff-y0
# wavelength-dependent
ii = np.interp(li, filt_lam, filt_ix, left=-1, right=10)
if ii == -1:
psf_xy_i = psf_xy_lam[0]*1
psf_ext_i = psf_ext_lam[0]*1
elif ii == 10:
psf_xy_i = psf_xy_lam[2]*1
psf_ext_i = psf_ext_lam[2]*1
else:
ni = int(ii)
f = 1-(li-filt_lam[ni])/(filt_lam[ni+1]-filt_lam[ni])
psf_xy_i = f*psf_xy_lam[ni] + (1-f)*psf_xy_lam[ni+1]
psf_ext_i = f*psf_ext_lam[ni] + (1-f)*psf_ext_lam[ni+1]
if not get_extended:
psf_ext_i = None
psf = EPSF.eval_ePSF(psf_xy_i, dx, dy, extended_data=psf_ext_i)
if len(self.psf_params) > 2:
psf *= self.psf_params[0]
#print(xi, psf.sum())
if seg_mask:
segm = nd.maximum_filter((self.seg == self.id)*1., size=7)
#yps, xps = np.indices(self.sh)
seg_i = nd.map_coordinates(segm, np.array([dx+x0, dy+y0]), order=1, mode='constant', cval=0.0, prefilter=True) > 0
else:
seg_i = 1
A_psf.append((psf*seg_i).flatten())
lam_psf.append(li)
# Sensitivity
self.lam_psf = np.array(lam_psf)
#photflam = photflam_list[psf_filter]
photflam = 1
if flat_sensitivity:
psf_sensitivity = np.abs(np.gradient(self.lam_psf))*photflam
else:
sens = self.conf.sens[self.beam]
# so = np.argsort(self.lam_psf)
# s_i = interp.interp_conserve_c(self.lam_psf[so], sens['WAVELENGTH'], sens['SENSITIVITY'], integrate=1)
# psf_sensitivity = s_i*0.
# psf_sensitivity[so] = s_i
if self.MW_F99 is not None:
MWext = 10**(-0.4*(self.MW_F99(sens['WAVELENGTH']*u.AA)))
else:
MWext = 1.
psf_sensitivity = self.get_psf_sensitivity(sens['WAVELENGTH'], sens['SENSITIVITY']*MWext)
self.psf_sensitivity = psf_sensitivity
self.A_psf = scipy.sparse.csr_matrix(np.array(A_psf).T)
# self.init_extended_epsf()
self.PAM_value = self.get_PAM_value()
self.psf_scale_to_data = 1.
self.psf_renorm = 1.
self.renormalize_epsf_model()
self.init_optimal_profile()
[docs] def get_psf_sensitivity(self, wave, sensitivity):
"""
Integrate the sensitivity curve to the wavelengths for the
PSF model
"""
from .utils_numba import interp
so = np.argsort(self.lam_psf)
s_i = interp.interp_conserve_c(self.lam_psf[so], wave, sensitivity, integrate=1)
psf_sensitivity = s_i*0.
psf_sensitivity[so] = s_i
return psf_sensitivity
[docs] def renormalize_epsf_model(self, spectrum_1d=None, verbose=False):
"""
Ensure normalization correct
"""
from .utils_numba import interp
if not hasattr(self, 'A_psf'):
print('ePSF not initialized')
return False
if spectrum_1d is None:
dl = 0.1
flat_x = np.arange(self.lam.min()-10, self.lam.max()+10, dl)
flat_y = flat_x*0.+1.e-17
spectrum_1d = [flat_x, flat_y]
tab = self.conf.sens[self.beam]
if self.MW_F99 is not None:
MWext = 10**(-0.4*(self.MW_F99(tab['WAVELENGTH']*u.AA)))
else:
MWext = 1.
sens_i = interp.interp_conserve_c(spectrum_1d[0], tab['WAVELENGTH'], tab['SENSITIVITY']*MWext, integrate=1, left=0, right=0)
total_sens = np.trapz(spectrum_1d[1]*sens_i/np.gradient(spectrum_1d[0]), spectrum_1d[0])
m = self.compute_model_psf(spectrum_1d=spectrum_1d, is_cgs=True, in_place=False).reshape(self.sh_beam)
#m2 = self.compute_model(spectrum_1d=[flat_x, flat_y], is_cgs=True, in_place=False).reshape(self.sh_beam)
renorm = total_sens / m.sum()
self.psf_renorm = renorm
# Scale model to data, depends on Pixel Area Map and PSF normalization
scale_to_data = self.PAM_value # * (self.psf_params[0]/0.975)
self.psf_scale_to_data = scale_to_data
renorm /= scale_to_data # renorm PSF
if verbose:
print('Renorm ePSF model: {0:0.3f}'.format(renorm))
self.A_psf *= renorm
[docs] def get_PAM_value(self, verbose=False):
"""
Apply Pixel Area Map correction to WFC3 effective PSF model
http://www.stsci.edu/hst/wfc3/pam/pixel_area_maps
"""
confp = self.conf.conf_dict
if ('INSTRUMENT' in confp) & ('CAMERA' in confp):
instr = '{0}-{1}'.format(confp['INSTRUMENT'], confp['CAMERA'])
if instr != 'WFC3-IR':
return 1
else:
return 1
try:
with pyfits.open(os.getenv('iref')+'ir_wfc3_map.fits') as pam:
pam_data = pam[1].data
pam_value = pam_data[int(self.yc-self.pad[0]),
int(self.xc-self.pad[1])]
pam.close()
except:
pam_value = 1
if verbose:
msg = 'PAM correction at x={0}, y={1}: {2:.3f}'
print(msg.format(self.xc-self.pad[1],
self.yc-self.pad[0],
pam_value))
return pam_value
[docs] def init_extended_epsf(self):
"""
Hacky code for adding extended component of the EPSFs
"""
ext_file = os.path.join(GRIZLI_PATH, 'CONF',
'ePSF_extended_splines.npy')
if not os.path.exists(ext_file):
return False
bg_splines = np.load(ext_file, allow_pickle=True)[0]
spline_waves = np.array(list(bg_splines.keys()))
spline_waves.sort()
spl_ix = np.arange(len(spline_waves))
yarr = np.arange(self.sh_beam[0]) - self.sh_beam[0]/2.+1
dy = self.psf_params[2]
spl_data = self.model * 0.
for i in range(self.sh_beam[1]):
dy_i = dy + self.ytrace[i]
x_i = np.interp(self.lam[i], spline_waves, spl_ix)
if (x_i == 0) | (x_i == len(bg_splines)-1):
spl_data[:, i] = bg_splines[spline_waves[int(x_i)]](yarr-dy_i)
else:
f = x_i-int(x_i)
sp = bg_splines[spline_waves[int(x_i)]](yarr-dy_i)*(1-f)
sp += bg_splines[spline_waves[int(x_i)+1]](yarr-dy_i)*f
spl_data[:, i] = sp
self.ext_psf_data = np.maximum(spl_data, 0)
[docs] def compute_model_psf(self, id=None, spectrum_1d=None, in_place=True, is_cgs=False, apply_sensitivity=True):
"""
Compute model with PSF morphology template
"""
from .utils_numba import interp
if spectrum_1d is None:
#modelf = np.array(self.A_psf.sum(axis=1)).flatten()
#model = model.reshape(self.sh_beam)
coeffs = np.ones(self.A_psf.shape[1])
if not is_cgs:
coeffs *= self.total_flux
else:
dx = np.diff(self.lam_psf)[0]
if dx < 0:
coeffs = interp.interp_conserve_c(self.lam_psf[::-1],
spectrum_1d[0],
spectrum_1d[1])[::-1]
else:
coeffs = interp.interp_conserve_c(self.lam_psf,
spectrum_1d[0],
spectrum_1d[1])
if not is_cgs:
coeffs *= self.total_flux
modelf = self.A_psf.dot(coeffs*self.psf_sensitivity).astype(np.float32)
model = modelf.reshape(self.sh_beam)
# if hasattr(self, 'ext_psf_data'):
# model += self.ext_psf_data*model.sum(axis=0)
# modelf = model.flatten()
# model = modelf.reshape(self.sh_beam)
if in_place:
self.spectrum_1d = spectrum_1d
self.is_cgs = is_cgs
self.modelf = modelf # .flatten()
self.model = model
#self.modelf = self.model.flatten()
return True
else:
return modelf # .flatten()
[docs]class ImageData(object):
"""Container for image data with WCS, etc."""
def __init__(self, sci=None, err=None, dq=None,
header=None, wcs=None, photflam=1., photplam=1.,
origin=[0, 0], pad=(0,0), process_jwst_header=True,
instrument='WFC3', filter='G141', pupil=None, module=None,
hdulist=None, restore_medfilt=False,
sci_extn=1, fwcpos=None):
"""
Parameters
----------
sci : `~numpy.ndarray`
Science data
err, dq : `~numpy.ndarray` or None
Uncertainty and DQ data. Defaults to zero if None
header : `~astropy.io.fits.Header`
Associated header with `data` that contains WCS information
wcs : `~astropy.wcs.WCS` or None
WCS solution to use. If `None` will derive from the `header`.
photflam : float
Multiplicative conversion factor to scale `data` to set units
to f_lambda flux density. If data is grism spectra, then use
photflam=1
origin : [int, int]
Origin of lower left pixel in detector coordinates
pad : int,int
Padding to apply to the image dimensions in numpy axis order
process_jwst_header : bool
If the image is detected as coming from JWST NIRISS or NIRCAM,
generate the necessary header WCS keywords
instrument : str
Instrument where the image came from
filter : str
Filter from the image header. For WFC3 and NIRISS this is the
dispersing element
pupil : str
Pupil from the image header (JWST instruments). For NIRISS this
is the blocking filter and for NIRCAM this is the dispersing
element
module : str
Instrument module for NIRCAM ('A' or 'B')
hdulist : `~astropy.io.fits.HDUList`, optional
If specified, read `sci`, `err`, `dq` from the HDU list from a
FITS file, e.g., WFC3 FLT.
sci_extn : int
Science EXTNAME to read from the HDUList, for example,
`sci` = hdulist['SCI',`sci_extn`].
fwcpos : float
Filter wheel encoder position (NIRISS)
Attributes
----------
parent_file : str
Filename of the parent from which the data were extracted
data : dict
Dictionary to store pixel data, with keys 'SCI', 'DQ', and 'ERR'.
If a reference image has been supplied and processed, will also
have an entry 'REF'. The data arrays can also be addressed with
the `__getitem__` method, i.e.,
>>> self = ImageData(...)
>>> print np.median(self['SCI'])
pad : int, int
Additional padding around the nominal image dimensions in
numpy array order
wcs : `~astropy.wcs.WCS`
WCS of the data array
header : `~astropy.io.fits.Header`
FITS header
filter, instrument, photflam, photplam, APZP : str, float
Parameters taken from the header
ref_file, ref_photlam, ref_photplam, ref_filter : str, float
Corresponding parameters for the reference image, if necessary.
"""
import copy
med_filter = None
bkg_array = None
# Easy way, get everything from an image HDU list
if isinstance(hdulist, pyfits.HDUList):
if ('REF', sci_extn) in hdulist:
ref_h = hdulist['REF', sci_extn].header
ref_data = hdulist['REF', sci_extn].data/ref_h['PHOTFLAM']
ref_data = np.asarray(ref_data,dtype=np.float32)
ref_file = ref_h['REF_FILE']
ref_photflam = 1.
ref_photplam = ref_h['PHOTPLAM']
#ref_filter = ref_h['FILTER']
ref_filter = utils.parse_filter_from_header(ref_h)
else:
ref_data = None
if ('SCI', sci_extn) in hdulist:
sci = np.asarray(hdulist['SCI', sci_extn].data,dtype=np.float32)
err = np.asarray(hdulist['ERR', sci_extn].data,dtype=np.float32)
dq = np.asarray(hdulist['DQ', sci_extn].data,dtype=np.int16)
if ('MED',sci_extn) in hdulist:
mkey = ('MED',sci_extn)
med_filter = np.asarray(hdulist[mkey].data,dtype=np.float32)
if restore_medfilt:
print('xxx put med filter back in')
sci += med_filter
if ('BKG',sci_extn) in hdulist:
mkey = ('BKG',sci_extn)
bkg_array = np.asarray(hdulist[mkey].data,dtype=np.float32)
base_extn = ('SCI', sci_extn)
else:
if ref_data is None:
raise KeyError('No SCI or REF extensions found')
# Doesn't have SCI, get from ref
sci = err = ref_data*0.+1
dq = np.zeros(sci.shape, dtype=np.int16)
base_extn = ('REF', sci_extn)
if 'ORIGINX' in hdulist[base_extn].header:
h0 = hdulist[base_extn].header
origin = [h0['ORIGINY'], h0['ORIGINX']]
else:
origin = [0, 0]
self.sci_extn = sci_extn
header = hdulist[base_extn].header.copy()
if 'PARENT' in header:
self.parent_file = header['PARENT']
else:
self.parent_file = hdulist.filename()
if 'CPDIS1' in header:
if 'Lookup' in header['CPDIS1']:
self.wcs_is_lookup = True
else:
self.wcs_is_lookup = False
else:
self.wcs_is_lookup = False
status = False
for ext in [base_extn, 0]:
h = hdulist[ext].header
if 'INSTRUME' in h:
status = True
break
if not status:
msg = ('Couldn\'t find \'INSTRUME\' keyword in the headers' +
' of extensions 0 or (SCI,{0:d})'.format(sci_extn))
raise KeyError(msg)
instrument = h['INSTRUME']
filter = utils.parse_filter_from_header(h, filter_only=True)
if 'PUPIL' in h:
pupil = h['PUPIL']
if 'MODULE' in h:
module = h['MODULE']
else:
module = None
if 'PHOTPLAM' in h:
photplam = h['PHOTPLAM']
elif filter in photplam_list:
photplam = photplam_list[filter]
else:
photplam = 1
if 'PHOTFLAM' in h:
photflam = h['PHOTFLAM']
elif filter in photflam_list:
photflam = photflam_list[filter]
elif 'PHOTUJA2' in header:
# JWST calibrated products
per_pix = header['PIXAR_SR']
if header['BUNIT'].strip() == 'MJy/sr':
photfnu = per_pix*1e6
else:
photfnu = 1./(header['PHOTMJSR']*1.e6)*per_pix
photflam = photfnu/1.e23*3.e18/photplam**2
else:
photflam = 1.
# For NIRISS
if 'FWCPOS' in h:
fwcpos = h['FWCPOS']
self.mdrizsky = 0.
if 'MDRIZSKY' in header:
#sci -= header['MDRIZSKY']
self.mdrizsky = header['MDRIZSKY']
# ACS bunit
#self.exptime = 1.
if 'EXPTIME' in hdulist[0].header:
self.exptime = hdulist[0].header['EXPTIME']
else:
self.exptime = hdulist[0].header['EFFEXPTM']
# if 'BUNIT' in header:
# if header['BUNIT'] == 'ELECTRONS':
# self.exptime = hdulist[0].header['EXPTIME']
# # sci /= self.exptime
# # err /= self.exptime
sci = (sci-self.mdrizsky)
if 'BUNIT' in header:
if header['BUNIT'] == 'ELECTRONS':
sci /= self.exptime
err /= self.exptime
if filter.startswith('G'):
photflam = 1
if (instrument == 'NIRCAM') & (pupil is not None):
if pupil.startswith('G'):
photflam = 1
if 'PAD' in header:
pad = [header['PAD'], header['PAD']]
elif ('PADX' in header) & ('PADY' in header):
pad = [header['PADY'], header['PADX']]
else:
pad = [0,0]
self.grow = 1
if 'GROW' in header:
self.grow = header['GROW']
else:
if sci is None:
sci = np.zeros((1014, 1014))
self.parent_file = 'Unknown'
self.sci_extn = None
self.grow = 1
ref_data = None
if 'EXPTIME' in header:
self.exptime = header['EXPTIME']
else:
self.exptime = 1.
if 'MDRIZSKY' in header:
self.mdrizsky = header['MDRIZSKY']
else:
self.mdrizsky = 0.
if 'CPDIS1' in header:
if 'Lookup' in header['CPDIS1']:
self.wcs_is_lookup = True
else:
self.wcs_is_lookup = False
else:
self.wcs_is_lookup = False
self.is_slice = False
# Array parameters
if isinstance(pad, int):
self.pad = [pad, pad]
else:
self.pad = pad
self.origin = origin
self.fwcpos = fwcpos # NIRISS
self.MW_EBV = 0.
self.data = OrderedDict()
self.data['SCI'] = sci*photflam
self.sh = np.array(self.data['SCI'].shape)
# Header-like parameters
self.filter = filter
self.pupil = pupil
if (instrument == 'NIRCAM'):
# Fallback if module not specified
if module is None:
if 'MODULE' not in header:
self.module = 'A'
else:
self.module = header['MODULE']
else:
self.module = module
else:
self.module = module
self.instrument = instrument
self.header = header
if 'ISCUTOUT' in self.header:
self.is_slice = self.header['ISCUTOUT']
self.header['EXPTIME'] = self.exptime
self.photflam = photflam
self.photplam = photplam
self.ABZP = (0*np.log10(self.photflam) - 21.10 -
5*np.log10(self.photplam) + 18.6921)
self.thumb_extension = 'SCI'
if err is None:
self.data['ERR'] = np.zeros_like(self.data['SCI'])
else:
self.data['ERR'] = err*photflam
if self.data['ERR'].shape != tuple(self.sh):
raise ValueError('err and sci arrays have different shapes!')
if dq is None:
self.data['DQ'] = np.zeros_like(self.data['SCI'], dtype=np.int16)
else:
self.data['DQ'] = dq
if self.data['DQ'].shape != tuple(self.sh):
raise ValueError('err and dq arrays have different shapes!')
if ref_data is None:
self.data['REF'] = None
self.ref_file = None
self.ref_photflam = None
self.ref_photplam = None
self.ref_filter = None
else:
self.data['REF'] = ref_data
self.ref_file = ref_file
self.ref_photflam = ref_photflam
self.ref_photplam = ref_photplam
self.ref_filter = ref_filter
if med_filter is not None:
self.data['MED'] = med_filter
if bkg_array is not None:
self.data['BKG'] = bkg_array
self.wcs = None
# if (instrument in ['NIRISS', 'NIRCAM']) & (~self.is_slice):
# if process_jwst_header:
# self.update_jwst_wcsheader(hdulist)
if self.header is not None:
if wcs is None:
self.get_wcs()
else:
self.wcs = wcs.copy()
if not hasattr(self.wcs, 'pixel_shape'):
self.wcs.pixel_shape = self.wcs._naxis1, self.wcs._naxis2
else:
self.header = pyfits.Header()
# Detector chip
if 'CCDCHIP' in self.header:
self.ccdchip = self.header['CCDCHIP']
else:
self.ccdchip = 1
# Galactic extinction
if 'MW_EBV' in self.header:
self.MW_EBV = self.header['MW_EBV']
else:
self.MW_EBV = 0.
[docs] def unset_dq(self):
"""Flip OK data quality bits using utils.mod_dq_bits
OK bits are defined as
>>> okbits_instrument = {'WFC3': 32+64+512, # blob OK
'NIRISS': 1+2+4,
'NIRCAM': 1+2+4,
'WFIRST': 0,
'WFI': 0}
"""
okbits_instrument = {'WFC3': 32+64+512, # blob OK
'NIRISS': 1+2+4, #+4096+4100+18432+18436+1024+16384+1,
'NIRCAM': 1+2+4,
'WFIRST': 0,
'WFI': 0}
if self.instrument not in okbits_instrument:
okbits = 1
else:
okbits = okbits_instrument[self.instrument]
self.data['DQ'] = utils.mod_dq_bits(self.data['DQ'], okbits=okbits)
[docs] def flag_negative(self, sigma=-3):
"""Flag negative data values with dq=4
Parameters
----------
sigma : float
Threshold for setting bad data
Returns
-------
n_negative : int
Number of flagged negative pixels
If `self.data['ERR']` is zeros, do nothing.
"""
if self.data['ERR'].max() == 0:
return 0
bad = self.data['SCI'] < sigma*self.data['ERR']
self.data['DQ'][bad] |= 4
return bad.sum()
[docs] def get_wcs(self, pc2cd=False):
"""Get WCS from header"""
import numpy.linalg
import stwcs
if self.wcs_is_lookup:
if 'CCDCHIP' in self.header:
ext = {1: 2, 2: 1}[self.header['CCDCHIP']]
else:
ext = self.header['EXTVER']
if os.path.exists(self.parent_file):
with pyfits.open(self.parent_file) as fobj:
wcs = stwcs.wcsutil.hstwcs.HSTWCS(fobj=fobj,
ext=('SCI', ext))
if np.max(self.pad) > 0:
wcs = self.add_padding_to_wcs(wcs, pad=self.pad)
else:
# Get WCS from a stripped wcs.fits file (from self.save_wcs)
# already padded.
wcsfile = self.parent_file.replace('.fits',
'.{0:02d}.wcs.fits'.format(ext))
with pyfits.open(wcsfile) as fobj:
fh = fobj[0].header
if fh['NAXIS'] == 0:
fh['NAXIS'] = 2
fh['NAXIS1'] = int(fh['CRPIX1']*2)
fh['NAXIS2'] = int(fh['CRPIX2']*2)
wcs = stwcs.wcsutil.hstwcs.HSTWCS(fobj=fobj, ext=0)
# Object is a cutout
if self.is_slice:
slx = slice(self.origin[1], self.origin[1]+self.sh[1])
sly = slice(self.origin[0], self.origin[0]+self.sh[0])
wcs = self.get_slice_wcs(wcs, slx=slx, sly=sly)
else:
fobj = None
wcs = pywcs.WCS(self.header, relax=True, fobj=fobj)
if not hasattr(wcs, 'pscale'):
wcs.pscale = utils.get_wcs_pscale(wcs)
self.wcs = wcs
if not hasattr(self.wcs, 'pixel_shape'):
self.wcs.pixel_shape = self.wcs._naxis1, self.wcs._naxis2
[docs] @staticmethod
def add_padding_to_wcs(wcs_in, pad=(64,256)):
"""Pad the appropriate WCS keywords
Parameters
----------
wcs_in : `~astropy.wcs.WCS`
Input WCS
pad : int, int
Number of pixels to pad, in array order (axis2, axis1)
Returns
-------
wcs_out : `~astropy.wcs.WCS`
Padded WCS
"""
wcs = wcs_in.deepcopy()
is_new = True
for attr in ['naxis1', '_naxis1']:
if hasattr(wcs, attr):
is_new = False
value = wcs.__getattribute__(attr)
if value is not None:
wcs.__setattr__(attr, value+2*pad[1])
for attr in ['naxis2', '_naxis2']:
if hasattr(wcs, attr):
is_new = False
value = wcs.__getattribute__(attr)
if value is not None:
wcs.__setattr__(attr, value+2*pad[0])
# Handle changing astropy.wcs.WCS attributes
if is_new:
#for i in range(len(wcs._naxis)):
# wcs._naxis[i] += 2*pad
wcs._naxis[0] += 2*pad[1]
wcs._naxis[1] += 2*pad[0]
wcs.naxis1, wcs.naxis2 = wcs._naxis
else:
wcs.naxis1 = wcs._naxis1
wcs.naxis2 = wcs._naxis2
wcs.wcs.crpix[0] += pad[1]
wcs.wcs.crpix[1] += pad[0]
# Pad CRPIX for SIP
for wcs_ext in [wcs.sip]:
if wcs_ext is not None:
wcs_ext.crpix[0] += pad[1]
wcs_ext.crpix[1] += pad[0]
# Pad CRVAL for Lookup Table, if necessary (e.g., ACS)
for wcs_ext in [wcs.cpdis1, wcs.cpdis2, wcs.det2im1, wcs.det2im2]:
if wcs_ext is not None:
wcs_ext.crval[0] += pad[1]
wcs_ext.crval[1] += pad[0]
return wcs
[docs] def add_padding(self, pad=(64,256)):
"""Pad the data array and update WCS keywords"""
if isinstance(pad, int):
_pad = [pad, pad]
else:
_pad = pad
# Update data array
new_sh = np.array([s for s in self.sh])
new_sh[0] += 2*pad[0]
new_sh[1] += 2*pad[1]
for key in ['SCI', 'ERR', 'DQ', 'REF']:
if key not in self.data:
continue
else:
if self.data[key] is None:
continue
data = self.data[key]
new_data = np.zeros(new_sh, dtype=data.dtype)
new_data[pad[0]:-pad[0], pad[1]:-pad[1]] += data
self.data[key] = new_data
self.sh = new_sh
for i in range(2):
self.pad[i] += _pad[i]
# Padded image dimensions
self.header['NAXIS1'] += 2*_pad[1]
self.header['NAXIS2'] += 2*_pad[0]
self.header['CRPIX1'] += _pad[1]
self.header['CRPIX2'] += _pad[0]
# Add padding to WCS
self.wcs = self.add_padding_to_wcs(self.wcs, pad=_pad)
if not hasattr(self.wcs, 'pixel_shape'):
self.wcs.pixel_shape = self.wcs._naxis1, self.wcs._naxis2
[docs] def shrink_large_hdu(self, hdu=None, extra=100, verbose=False):
"""Shrink large image mosaic to speed up blotting
Parameters
----------
hdu : `~astropy.io.fits.ImageHDU`
Input reference HDU
extra : int
Extra border to put around `self.data` WCS to ensure the reference
image is large enough to encompass the distorted image
Returns
-------
new_hdu : `~astropy.io.fits.ImageHDU`
Image clipped to encompass `self.data['SCI']` + margin of `extra`
pixels.
Make a cutout of the larger reference image around the desired FLT
image to make blotting faster for large reference images.
"""
ref_wcs = pywcs.WCS(hdu.header)
# Borders of the flt frame
naxis = [self.header['NAXIS1'], self.header['NAXIS2']]
xflt = [-extra, naxis[0]+extra, naxis[0]+extra, -extra]
yflt = [-extra, -extra, naxis[1]+extra, naxis[1]+extra]
raflt, deflt = self.wcs.all_pix2world(xflt, yflt, 0)
xref, yref = np.asarray(ref_wcs.all_world2pix(raflt, deflt, 0),dtype=int)
ref_naxis = [hdu.header['NAXIS1'], hdu.header['NAXIS2']]
# Slices of the reference image
xmi = np.maximum(0, xref.min())
xma = np.minimum(ref_naxis[0], xref.max())
slx = slice(xmi, xma)
ymi = np.maximum(0, yref.min())
yma = np.minimum(ref_naxis[1], yref.max())
sly = slice(ymi, yma)
if ((xref.min() < 0) | (yref.min() < 0) |
(xref.max() > ref_naxis[0]) | (yref.max() > ref_naxis[1])):
if verbose:
msg = 'Image cutout: x={0}, y={1} [Out of range]'
print(msg.format(slx, sly))
return hdu
else:
if verbose:
print('Image cutout: x={0}, y={1}'.format(slx, sly))
# Sliced subimage
slice_wcs = ref_wcs.slice((sly, slx))
slice_header = hdu.header.copy()
#hwcs = slice_wcs.to_header(relax=True)
hwcs = utils.to_header(slice_wcs, relax=True)
for k in hwcs.keys():
if not k.startswith('PC'):
slice_header[k] = hwcs[k]
slice_data = hdu.data[sly, slx]*1
new_hdu = pyfits.ImageHDU(data=slice_data, header=slice_header)
return new_hdu
[docs] def expand_hdu(self, hdu=None, verbose=True):
"""TBD
"""
ref_wcs = pywcs.WCS(hdu.header)
# Borders of the flt frame
naxis = [self.header['NAXIS1'], self.header['NAXIS2']]
xflt = [-self.pad[1], naxis[0]+self.pad[1],
naxis[0]+self.pad[1], -self.pad[1]]
yflt = [-self.pad[0], -self.pad[0],
naxis[1]+self.pad[0], naxis[1]+self.pad[0]]
raflt, deflt = self.wcs.all_pix2world(xflt, yflt, 0)
xref, yref = np.asarray(ref_wcs.all_world2pix(raflt, deflt, 0),dtype=int)
ref_naxis = [hdu.header['NAXIS1'], hdu.header['NAXIS2']]
pad_min = np.minimum(xref.min(), yref.min())
pad_max = np.maximum((xref-ref_naxis[0]).max(),
(yref-ref_naxis[1]).max())
if (pad_min > 0) & (pad_max < 0):
# do nothing
return hdu
pad = np.maximum(np.abs(pad_min), pad_max) + 64
if verbose:
msg = '{0} / Pad ref HDU with {1:d} pixels'
print(msg.format(self.parent_file, pad))
# Update data array
sh = hdu.data.shape
new_sh = np.array(sh) + 2*pad
new_data = np.zeros(new_sh, dtype=hdu.data.dtype)
new_data[pad:-pad, pad:-pad] += hdu.data
header = hdu.header.copy()
# Padded image dimensions
header['NAXIS1'] += 2*pad
header['NAXIS2'] += 2*pad
# Add padding to WCS
header['CRPIX1'] += pad
header['CRPIX2'] += pad
new_hdu = pyfits.ImageHDU(data=new_data, header=header)
return new_hdu
[docs] def blot_from_hdu(self, hdu=None, segmentation=False, grow=3,
interp='nearest'):
"""Blot a rectified reference image to detector frame
Parameters
----------
hdu : `~astropy.io.fits.ImageHDU`
HDU of the reference image
segmentation : bool, False
If True, treat the reference image as a segmentation image and
preserve the integer values in the blotting.
If specified as number > 1, then use `~grizli.utils.blot_nearest_exact`
rather than a hacky pixel area ratio method to blot integer
segmentation maps.
grow : int, default=3
Number of pixels to dilate the segmentation regions
interp : str,
Form of interpolation to use when blotting float image pixels.
Valid options: {'nearest', 'linear', 'poly3', 'poly5' (default), 'spline3', 'sinc'}
Returns
-------
blotted : `np.ndarray`
Blotted array with the same shape and WCS as `self.data['SCI']`.
"""
import astropy.wcs
from drizzlepac import astrodrizzle
#ref = pyfits.open(refimage)
if hdu.data.dtype.type != np.float32:
#hdu.data = np.asarray(hdu.data,dtype=np.float32)
refdata = np.asarray(hdu.data,dtype=np.float32)
else:
refdata = hdu.data
if 'ORIENTAT' in hdu.header.keys():
hdu.header.remove('ORIENTAT')
if segmentation:
seg_ones = np.asarray(refdata > 0,dtype=np.float32)-1
ref_wcs = pywcs.WCS(hdu.header, relax=True)
flt_wcs = self.wcs.copy()
# Fix some wcs attributes that might not be set correctly
for wcs in [ref_wcs, flt_wcs]:
if hasattr(wcs, '_naxis1'):
wcs.naxis1 = wcs._naxis1
wcs.naxis2 = wcs._naxis2
else:
wcs._naxis1, wcs._naxis2 = wcs._naxis
if (not hasattr(wcs.wcs, 'cd')) & hasattr(wcs.wcs, 'pc'):
wcs.wcs.cd = wcs.wcs.pc
if hasattr(wcs, 'idcscale'):
if wcs.idcscale is None:
wcs.idcscale = np.mean(np.sqrt(np.sum(wcs.wcs.cd**2, axis=0))*3600.) # np.sqrt(np.sum(wcs.wcs.cd[0,:]**2))*3600.
else:
#wcs.idcscale = np.sqrt(np.sum(wcs.wcs.cd[0,:]**2))*3600.
wcs.idcscale = np.mean(np.sqrt(np.sum(wcs.wcs.cd**2, axis=0))*3600.) # np.sqrt(np.sum(wcs.wcs.cd[0,:]**2))*3600.
wcs.pscale = utils.get_wcs_pscale(wcs)
if segmentation:
# Handle segmentation images a bit differently to preserve
# integers.
# +1 here is a hack for some memory issues
if segmentation*1 == 1:
seg_interp = 'nearest'
blotted_ones = astrodrizzle.ablot.do_blot(seg_ones+1, ref_wcs,
flt_wcs, 1, coeffs=True,
interp=seg_interp,
sinscl=1.0, stepsize=10, wcsmap=None)
blotted_seg = astrodrizzle.ablot.do_blot(refdata*1., ref_wcs,
flt_wcs, 1, coeffs=True,
interp=seg_interp,
sinscl=1.0, stepsize=10, wcsmap=None)
blotted_ones[blotted_ones == 0] = 1
#pixel_ratio = (flt_wcs.idcscale / ref_wcs.idcscale)**2
#in_seg = np.abs(blotted_ones - pixel_ratio) < 1.e-2
ratio = np.round(blotted_seg/blotted_ones)
seg = nd.maximum_filter(ratio, size=grow,
mode='constant', cval=0)
ratio[ratio == 0] = seg[ratio == 0]
blotted = ratio
else:
blotted = utils.blot_nearest_exact(refdata, ref_wcs, flt_wcs,
verbose=True, stepsize=-1,
scale_by_pixel_area=False,
wcs_mask=True,
fill_value=0)
else:
# Floating point data
blotted = astrodrizzle.ablot.do_blot(refdata, ref_wcs, flt_wcs, 1,
coeffs=True, interp=interp, sinscl=1.0,
stepsize=10, wcsmap=None)
return blotted
[docs] @staticmethod
def get_slice_wcs(wcs, slx=slice(480, 520), sly=slice(480, 520)):
"""Get slice of a WCS including higher orders like SIP and DET2IM
The normal `~astropy.wcs.wcs.WCS` `slice` method doesn't apply the
slice to all of the necessary keywords. For example, SIP WCS also
has a `CRPIX` reference pixel that needs to be offset along with
the main `CRPIX`.
Parameters
----------
slx, sly : slice
Slices in x and y dimensions to extract
"""
NX = slx.stop - slx.start
NY = sly.stop - sly.start
slice_wcs = wcs.slice((sly, slx))
if hasattr(slice_wcs, '_naxis1'):
slice_wcs.naxis1 = slice_wcs._naxis1 = NX
slice_wcs.naxis2 = slice_wcs._naxis2 = NY
else:
slice_wcs._naxis = [NX, NY]
slice_wcs._naxis1, slice_wcs._naxis2 = NX, NY
if hasattr(slice_wcs, 'sip'):
if slice_wcs.sip is not None:
for c in [0, 1]:
slice_wcs.sip.crpix[c] = slice_wcs.wcs.crpix[c]
ACS_CRPIX = [4096/2, 2048/2] # ACS
dx_crpix = slice_wcs.wcs.crpix[0] - ACS_CRPIX[0]
dy_crpix = slice_wcs.wcs.crpix[1] - ACS_CRPIX[1]
for ext in ['cpdis1', 'cpdis2', 'det2im1', 'det2im2']:
if hasattr(slice_wcs, ext):
wcs_ext = slice_wcs.__getattribute__(ext)
if wcs_ext is not None:
wcs_ext.crval[0] += dx_crpix
wcs_ext.crval[1] += dy_crpix
slice_wcs.__setattr__(ext, wcs_ext)
return slice_wcs
[docs] def get_slice(self, slx=slice(480, 520), sly=slice(480, 520),
get_slice_header=True):
"""Return cutout version of the `ImageData` object
Parameters
----------
slx, sly : slice
Slices in x and y dimensions to extract
get_slice_header : bool
Compute the full header of the slice. This takes a bit of time
and isn't necessary in all cases so can be omitted if only the
sliced data are of interest and the header isn't needed.
Returns
-------
slice_obj : `ImageData`
New `ImageData` object of the sliced subregion
"""
origin = [sly.start, slx.start]
NX = slx.stop - slx.start
NY = sly.stop - sly.start
# Test dimensions
if (origin[0] < 0) | (origin[0]+NY > self.sh[0]):
raise ValueError('Out of range in y')
if (origin[1] < 0) | (origin[1]+NX > self.sh[1]):
raise ValueError('Out of range in x')
# Sliced subimage
# sly = slice(origin[0], origin[0]+N)
# slx = slice(origin[1], origin[1]+N)
slice_origin = [self.origin[i] + origin[i] for i in range(2)]
slice_wcs = self.get_slice_wcs(self.wcs, slx=slx, sly=sly)
# slice_wcs = self.wcs.slice((sly, slx))
#slice_wcs.naxis1 = slice_wcs._naxis1 = NX
#slice_wcs.naxis2 = slice_wcs._naxis2 = NY
# Getting the full header can be slow as there appears to
# be substantial overhead with header.copy() and wcs.to_header()
if get_slice_header:
slice_header = self.header.copy()
slice_header['NAXIS1'] = NX
slice_header['NAXIS2'] = NY
# Sliced WCS keywords
hwcs = utils.to_header(slice_wcs, relax=True)
for k in hwcs:
if not k.startswith('PC'):
slice_header[k] = hwcs[k]
else:
cd = k.replace('PC', 'CD')
slice_header[cd] = hwcs[k]
else:
slice_header = pyfits.Header()
# Generate new object
if (self.data['REF'] is not None) & (self.data['SCI'] is None):
_sci = _err = _dq = None
else:
_sci = self.data['SCI'][sly, slx]/self.photflam
_err = self.data['ERR'][sly, slx]/self.photflam
_dq = self.data['DQ'][sly, slx]*1
slice_obj = ImageData(sci=_sci, err=_err, dq=_dq,
header=slice_header, wcs=slice_wcs,
photflam=self.photflam, photplam=self.photplam,
origin=slice_origin, instrument=self.instrument,
filter=self.filter, pupil=self.pupil,
module=self.module,
process_jwst_header=False)
slice_obj.ref_photflam = self.ref_photflam
slice_obj.ref_photplam = self.ref_photplam
slice_obj.ref_filter = self.ref_filter
slice_obj.mdrizsky = self.mdrizsky
slice_obj.exptime = self.exptime
slice_obj.ABZP = self.ABZP
slice_obj.thumb_extension = self.thumb_extension
if self.data['REF'] is not None:
slice_obj.data['REF'] = self.data['REF'][sly, slx]*1
else:
slice_obj.data['REF'] = None
if 'MED' in self.data:
slice_obj.data['MED'] = self.data['MED'][sly, slx]*1
if 'BKG' in self.data:
slice_obj.data['BKG'] = self.data['BKG'][sly, slx]*1
slice_obj.grow = self.grow
slice_obj.pad = self.pad
slice_obj.parent_file = self.parent_file
slice_obj.ref_file = self.ref_file
slice_obj.sci_extn = self.sci_extn
slice_obj.is_slice = True
# if hasattr(slice_obj.wcs, 'sip'):
# if slice_obj.wcs.sip is not None:
# for c in [0,1]:
# slice_obj.wcs.sip.crpix[c] = slice_obj.wcs.wcs.crpix[c]
#
# ACS_CRPIX = [4096/2,2048/2] # ACS
# dx_crpix = slice_obj.wcs.wcs.crpix[0] - ACS_CRPIX[0]
# dy_crpix = slice_obj.wcs.wcs.crpix[1] - ACS_CRPIX[1]
# for ext in ['cpdis1','cpdis2','det2im1','det2im2']:
# if hasattr(slice_obj.wcs, ext):
# wcs_ext = slice_obj.wcs.__getattribute__(ext)
# if wcs_ext is not None:
# wcs_ext.crval[0] += dx_crpix
# wcs_ext.crval[1] += dy_crpix
# slice_obj.wcs.__setattr__(ext, wcs_ext)
return slice_obj # , slx, sly
[docs] def get_HDUList(self, extver=1):
"""Convert attributes and data arrays to a `~astropy.io.fits.HDUList`
Parameters
----------
extver : int, float, str
value to use for the 'EXTVER' header keyword. For example, with
extver=1, the science extension can be addressed with the index
`HDU['SCI',1]`.
returns : `~astropy.io.fits.HDUList`
HDUList with header keywords copied from `self.header` along with
keywords for additional attributes. Will have `ImageHDU`
extensions 'SCI', 'ERR', and 'DQ', as well as 'REF' if a reference
file had been supplied.
"""
h = self.header.copy()
h['EXTVER'] = extver # self.filter #extver
h['FILTER'] = self.filter, 'element selected from filter wheel'
h['PUPIL'] = self.pupil, 'element selected from pupil wheel'
h['INSTRUME'] = (self.instrument,
'identifier for instrument used to acquire data')
if self.module is not None:
h['MODULE'] = self.module, 'Instrument module'
h['PHOTFLAM'] = (self.photflam,
'inverse sensitivity, ergs/cm2/Ang/electron')
h['PHOTPLAM'] = self.photplam, 'Pivot wavelength (Angstroms)'
h['PARENT'] = self.parent_file, 'Parent filename'
h['SCI_EXTN'] = self.sci_extn, 'EXTNAME of the science data'
h['ISCUTOUT'] = self.is_slice, 'Arrays are sliced from larger image'
h['ORIGINX'] = self.origin[1], 'Origin from parent image, x'
h['ORIGINY'] = self.origin[0], 'Origin from parent image, y'
if isinstance(self.pad, int):
_pad = (self.pad, self.pad)
else:
_pad = self.pad
h['PADX'] = (_pad[1], 'Image padding used axis1')
h['PADY'] = (_pad[0], 'Image padding used axis2')
hdu = []
exptime_corr = 1.
if 'BUNIT' in self.header:
if self.header['BUNIT'] == 'ELECTRONS':
exptime_corr = self.exptime
# Put back into original units
sci_data = self['SCI']*exptime_corr + self.mdrizsky
err_data = self['ERR']*exptime_corr
hdu.append(pyfits.ImageHDU(data=sci_data, header=h,
name='SCI'))
hdu.append(pyfits.ImageHDU(data=err_data, header=h,
name='ERR'))
hdu.append(pyfits.ImageHDU(data=self.data['DQ'], header=h, name='DQ'))
if 'MED' in self.data:
hdu.append(pyfits.ImageHDU(data=self.data['MED'],
header=h, name='MED'))
if 'BKG' in self.data:
hdu.append(pyfits.ImageHDU(data=self.data['BKG'],
header=h, name='BKG'))
if self.data['REF'] is not None:
h['PHOTFLAM'] = self.ref_photflam
h['PHOTPLAM'] = self.ref_photplam
h['FILTER'] = self.ref_filter
h['REF_FILE'] = self.ref_file
hdu.append(pyfits.ImageHDU(data=self.data['REF'], header=h,
name='REF'))
hdul = pyfits.HDUList(hdu)
return hdul
def __getitem__(self, ext):
if self.data[ext] is None:
return None
if ext == 'REF':
return self.data['REF']/self.ref_photflam
elif ext == 'DQ':
return self.data['DQ']
else:
return self.data[ext]/self.photflam
[docs] def get_common_slices(self, other, verify_parent=True):
"""
Get slices of overlaps between two `ImageData` objects
"""
if verify_parent:
if self.parent_file != other.parent_file:
msg = ('Parent expodures don\'t match!\n' +
' self: {0}\n'.format(self.parent_file) +
' other: {0}\n'.format(other.parent_file))
raise IOError(msg)
ll = np.min([self.origin, other.origin], axis=0)
ur = np.max([self.origin+self.sh, other.origin+other.sh], axis=0)
# other in self
lls = np.minimum(other.origin - ll, self.sh)
urs = np.clip(other.origin + self.sh - self.origin, [0, 0], self.sh)
# self in other
llo = np.minimum(self.origin - ll, other.sh)
uro = np.clip(self.origin + other.sh - other.origin, [0, 0], other.sh)
self_slice = (slice(lls[0], urs[0]), slice(lls[1], urs[1]))
other_slice = (slice(llo[0], uro[0]), slice(llo[1], uro[1]))
return self_slice, other_slice
[docs]class GrismFLT(object):
"""Scripts for modeling of individual grism FLT images"""
def __init__(self, grism_file='', sci_extn=1, direct_file='',
pad=(64,256), ref_file=None, ref_ext=0, seg_file=None,
shrink_segimage=True, force_grism='G141', verbose=True,
process_jwst_header=True, use_jwst_crds=False):
"""Read FLT files and, optionally, reference/segmentation images.
Parameters
----------
grism_file : str
Grism image (optional).
Empty string or filename of a FITS file that must contain
extensions ('SCI', `sci_extn`), ('ERR', `sci_extn`), and
('DQ', `sci_extn`). For example, a WFC3/IR "FLT" FITS file.
sci_extn : int
EXTNAME of the file to consider. For WFC3/IR this can only be
1. For ACS and WFC3/UVIS, this can be 1 or 2 to specify the two
chips.
direct_file : str
Direct image (optional).
Empty string or filename of a FITS file that must contain
extensions ('SCI', `sci_extn`), ('ERR', `sci_extn`), and
('DQ', `sci_extn`). For example, a WFC3/IR "FLT" FITS file.
pad : int, int
Padding to add around the periphery of the images to allow
modeling of dispersed spectra for objects that could otherwise
fall off of the direct image itself. Modeling them requires an
external reference image (`ref_file`) that covers an area larger
than the individual direct image itself (e.g., a mosaic of a
survey field).
For WFC3/IR spectra, the first order spectra reach 248 and 195
pixels for G102 and G141, respectively, and `pad` could be set
accordingly if the reference image is large enough.
ref_file : str or `~astropy.io.fits.ImageHDU`/`~astropy.io.fits.PrimaryHDU`
Image mosaic to use as the reference image in place of the direct
image itself. For example, this could be the deeper image
drizzled from all direct images taken within a single visit or it
could be a much deeper/wider image taken separately in perhaps
even a different filter.
.. note::
Assumes that the WCS are aligned between `grism_file`,
`direct_file` and `ref_file`!
ref_ext : int
FITS extension to use if `ref_file` is a filename string.
seg_file : str or `~astropy.io.fits.ImageHDU`/`~astropy.io.fits.PrimaryHDU`
Segmentation image mosaic to associate pixels with discrete
objects. This would typically be generated from a rectified
image like `ref_file`, though here it is not required that
`ref_file` and `seg_file` have the same image dimensions but
rather just that the WCS are aligned between them.
shrink_segimage : bool
Try to make a smaller cutout of the reference images to speed
up blotting and array copying. This is most helpful for very
large input mosaics.
force_grism : str
Use this grism in "simulation mode" where only `direct_file` is
specified.
verbose : bool
Print status messages to the terminal.
Attributes
----------
grism, direct : `ImageData`
Grism and direct image data and parameters
conf : `~grizli.grismconf.aXeConf`
Grism configuration object.
seg : array-like
Segmentation image array.
model : array-like
Model of the grism exposure with the same dimensions as the
full detector array.
object_dispersers : dict
Container for storing information about what objects have been
added to the model of the grism exposure
catalog : `~astropy.table.Table`
Associated photometric catalog. Not required.
"""
import stwcs.wcsutil
# Read files
self.grism_file = grism_file
_GRISM_OPEN = False
if os.path.exists(grism_file):
grism_im = pyfits.open(grism_file)
_GRISM_OPEN = True
if grism_im[0].header['INSTRUME'] == 'ACS':
wcs = stwcs.wcsutil.HSTWCS(grism_im, ext=('SCI', sci_extn))
else:
wcs = None
self.grism = ImageData(hdulist=grism_im, sci_extn=sci_extn,
wcs=wcs,
process_jwst_header=process_jwst_header)
else:
if (grism_file is None) | (grism_file == ''):
self.grism = None
else:
print('\nFile not found: {0}!\n'.format(grism_file))
raise IOError
self.direct_file = direct_file
_DIRECT_OPEN = False
if os.path.exists(direct_file):
direct_im = pyfits.open(direct_file)
_DIRECT_OPEN = True
if direct_im[0].header['INSTRUME'] == 'ACS':
wcs = stwcs.wcsutil.HSTWCS(direct_im, ext=('SCI', sci_extn))
else:
wcs = None
self.direct = ImageData(hdulist=direct_im, sci_extn=sci_extn,
wcs=wcs,
process_jwst_header=process_jwst_header)
else:
if (direct_file is None) | (direct_file == ''):
self.direct = None
else:
print('\nFile not found: {0}!\n'.format(direct_file))
raise IOError
### Simulation mode, no grism exposure
if isinstance(pad, int):
self.pad = [pad, pad]
else:
self.pad = pad
if self.grism is not None:
if np.max(self.grism.pad) > 0:
self.pad = self.grism.pad
if (self.grism is None) & (self.direct is not None):
self.grism = ImageData(hdulist=direct_im, sci_extn=sci_extn)
self.grism_file = self.direct_file
self.grism.filter = force_grism
# Grism exposure only, assumes will get reference from ref_file
if (self.direct is None) & (self.grism is not None):
self.direct = ImageData(hdulist=grism_im, sci_extn=sci_extn)
self.direct_file = self.grism_file
# Add padding
if self.direct is not None:
if np.max(self.pad) > 0:
self.direct.add_padding(self.pad)
self.direct.unset_dq()
nbad = self.direct.flag_negative(sigma=-3)
self.direct.data['SCI'] *= (self.direct.data['DQ'] == 0)
self.direct.data['SCI'] *= (self.direct.data['ERR'] > 0)
if self.grism is not None:
if np.max(self.pad) > 0:
self.grism.add_padding(self.pad)
self.pad = self.grism.pad
self.grism.unset_dq()
nbad = self.grism.flag_negative(sigma=-3)
self.grism.data['SCI'] *= (self.grism.data['DQ'] == 0)
self.grism.data['SCI'] *= (self.grism.data['ERR'] > 0)
# Load data from saved model files, if available
# if os.path.exists('%s_model.fits' %(self.grism_file)):
# pass
# Holder for the full grism model array
self.model = np.zeros_like(self.direct.data['SCI'])
# Grism configuration
if self.grism.instrument in ['NIRCAM', 'NIRISS']:
direct_filter = self.grism.pupil
elif 'DFILTER' in self.grism.header:
direct_filter = self.grism.header['DFILTER']
else:
direct_filter = self.direct.filter
conf_args = dict(instrume=self.grism.instrument,
filter=direct_filter,
grism=self.grism.filter,
pupil=self.grism.pupil,
module=self.grism.module,
chip=self.grism.ccdchip,
use_jwst_crds=use_jwst_crds)
if 'CONFFILE' in self.grism.header:
self.conf_file = self.grism.header['CONFFILE']
else:
self.conf_file = grismconf.get_config_filename(**conf_args)
self.grism.header['CONFFILE'] = self.conf_file
self.conf = grismconf.load_grism_config(self.conf_file)
self.object_dispersers = OrderedDict()
# Blot reference image
self.process_ref_file(ref_file, ref_ext=ref_ext,
shrink_segimage=shrink_segimage,
verbose=verbose)
# Blot segmentation image
self.process_seg_file(seg_file, shrink_segimage=shrink_segimage,
verbose=verbose)
# End things
self.get_dispersion_PA()
self.catalog = None
self.catalog_file = None
self.is_rotated = False
self.has_edge_mask = False
# Cleanup
if _GRISM_OPEN:
grism_im.close()
if _DIRECT_OPEN:
direct_im.close()
[docs] def process_ref_file(self, ref_file, ref_ext=0, shrink_segimage=True,
verbose=True):
"""Read and blot a reference image
Parameters
----------
ref_file : str or `~astropy.fits.io.ImageHDU` / `~astropy.fits.io.PrimaryHDU`
Filename or `astropy.io.fits` Image HDU of the reference image.
shrink_segimage : bool
Try to make a smaller cutout of the reference image to speed
up blotting and array copying. This is most helpful for very
large input mosaics.
verbose : bool
Print some status information to the terminal
Returns
-------
status : bool
False if `ref_file` is None. True if completes successfully.
The blotted reference image is stored in the array attribute
`self.direct.data['REF']`.
The `ref_filter` attribute is determined from the image header and the
`ref_photflam` scaling is taken either from the header if possible, or
the global `photflam` variable defined at the top of this file.
"""
if ref_file is None:
return False
if (isinstance(ref_file, pyfits.ImageHDU) |
isinstance(ref_file, pyfits.PrimaryHDU)):
self.ref_file = ref_file.fileinfo()['file'].name
ref_str = ''
ref_hdu = ref_file
_IS_OPEN = False
else:
self.ref_file = ref_file
ref_str = '{0}[0]'.format(self.ref_file)
_IS_OPEN = True
ref_im = pyfits.open(ref_file, load_lazy_hdus=False)
ref_hdu = ref_im[ref_ext]
refh = ref_hdu.header
if shrink_segimage:
ref_hdu = self.direct.shrink_large_hdu(ref_hdu,
extra=np.max(self.pad),
verbose=True)
if verbose:
msg = '{0} / blot reference {1}'
print(msg.format(self.direct_file, ref_str))
blotted_ref = self.grism.blot_from_hdu(hdu=ref_hdu,
segmentation=False, interp='poly5')
header_values = {}
self.direct.ref_filter = utils.parse_filter_from_header(refh)
self.direct.ref_file = ref_str
key_list = {'PHOTFLAM': photflam_list, 'PHOTPLAM': photplam_list}
for key in ['PHOTFLAM', 'PHOTPLAM']:
if key in refh:
try:
header_values[key] = ref_hdu.header[key]*1.
except TypeError:
msg = 'Problem processing header keyword {0}: ** {1} **'
print(msg.format(key, ref_hdu.header[key]))
raise TypeError
else:
filt = self.direct.ref_filter
if filt in key_list[key]:
header_values[key] = key_list[key][filt]
else:
msg = 'Filter "{0}" not found in {1} tabulated list'
print(msg.format(filt, key))
raise IndexError
# Found keywords
self.direct.ref_photflam = header_values['PHOTFLAM']
self.direct.ref_photplam = header_values['PHOTPLAM']
# TBD: compute something like a cross-correlation offset
# between blotted reference and the direct image itself
self.direct.data['REF'] = np.asarray(blotted_ref,dtype=np.float32)
# print self.direct.data['REF'].shape, self.direct.ref_photflam
self.direct.data['REF'] *= self.direct.ref_photflam
# Fill empty pixels in the reference image from the SCI image,
# but don't do it if direct['SCI'] is just a copy from the grism
# if not self.direct.filter.startswith('G'):
# empty = self.direct.data['REF'] == 0
# self.direct.data['REF'][empty] += self.direct['SCI'][empty]
# self.direct.data['ERR'] *= 0.
# self.direct.data['DQ'] *= 0
self.direct.ABZP = (0*np.log10(self.direct.ref_photflam) - 21.10 -
5*np.log10(self.direct.ref_photplam) + 18.6921)
self.direct.thumb_extension = 'REF'
if _IS_OPEN:
ref_im.close()
# refh['FILTER'].upper()
return True
[docs] def process_seg_file(self, seg_file, shrink_segimage=True, verbose=True):
"""Read and blot a rectified segmentation image
Parameters
----------
seg_file : str or `~astropy.fits.io.ImageHDU` / `~astropy.fits.io.PrimaryHDU`
Filename or `astropy.io.fits` Image HDU of the segmentation image.
shrink_segimage : bool
Try to make a smaller cutout of the segmentation image to speed
up blotting and array copying. This is most helpful for very
large input mosaics.
verbose : bool
Print some status information to the terminal
Returns
-------
The blotted segmentation image is stored in the attribute `GrismFLT.seg`.
"""
if seg_file is not None:
if (isinstance(seg_file, pyfits.ImageHDU) |
isinstance(seg_file, pyfits.PrimaryHDU)):
self.seg_file = ''
seg_str = ''
seg_hdu = seg_file
segh = seg_hdu.header
_IS_OPEN = False
else:
self.seg_file = seg_file
seg_str = '{0}[0]'.format(self.seg_file)
seg_im = pyfits.open(seg_file)
seg_hdu = seg_im[0]
_IS_OPEN = True
if shrink_segimage:
seg_hdu = self.direct.shrink_large_hdu(seg_hdu,
extra=np.max(self.pad),
verbose=True)
# Make sure image big enough
seg_hdu = self.direct.expand_hdu(seg_hdu)
if verbose:
msg = '{0} / blot segmentation {1}'
print(msg.format(self.direct_file, seg_str))
blotted_seg = self.grism.blot_from_hdu(hdu=seg_hdu,
segmentation=True, grow=3,
interp='poly5')
self.seg = blotted_seg
if _IS_OPEN:
seg_im.close()
else:
self.seg = np.zeros(self.direct.sh, dtype=np.float32)
[docs] def get_dispersion_PA(self, decimals=0):
"""Compute exact PA of the dispersion axis, including tilt of the
trace and the FLT WCS
Parameters
----------
decimals : int or None
Number of decimal places to round to, passed to `~numpy.round`.
If None, then don't round.
Returns
-------
dispersion_PA : float
PA (angle East of North) of the dispersion axis.
"""
from astropy.coordinates import Angle
import astropy.units as u
# extra tilt of the 1st order grism spectra
if 'BEAMA' in self.conf.conf_dict:
x0 = self.conf.conf_dict['BEAMA']
else:
x0 = np.array([10,30])
dy_trace, lam_trace = self.conf.get_beam_trace(x=507, y=507, dx=x0,
beam='A')
extra = np.arctan2(dy_trace[1]-dy_trace[0], x0[1]-x0[0])/np.pi*180
# Distorted WCS
crpix = self.direct.wcs.wcs.crpix
xref = [crpix[0], crpix[0]+1]
yref = [crpix[1], crpix[1]]
r, d = self.direct.wcs.all_pix2world(xref, yref, 1)
pa = Angle((extra +
np.arctan2(np.diff(r)*np.cos(d[0]/180*np.pi),
np.diff(d))[0]/np.pi*180)*u.deg)
dispersion_PA = pa.wrap_at(360*u.deg).value
if decimals is not None:
dispersion_PA = np.round(dispersion_PA, decimals=decimals)
self.dispersion_PA = dispersion_PA
return float(dispersion_PA)
[docs] def compute_model_orders(self,
id=0,
x=None,
y=None,
size=10,
mag=-1,
spectrum_1d=None,
is_cgs=False,
compute_size=False,
max_size=None,
min_size=26,
store=True,
in_place=True,
get_beams=None,
psf_params=None,
verbose=True):
"""Compute dispersed spectrum for a given object id
Parameters
----------
id : int
Object ID number to match in the segmentation image
x, y : float
Center of the cutout to extract
size : int
Radius of the cutout to extract. The cutout is equivalent to
>>> xc, yc = int(x), int(y)
>>> thumb = self.direct.data['SCI'][yc-size:yc+size, xc-size:xc+size]
mag : float
Specified object magnitude, which will be compared to the
"MMAG_EXTRACT_[BEAM]" parameters in `self.conf` to decide if the
object is bright enough to compute the higher spectral orders.
Default of -1 means compute all orders listed in `self.conf.beams`
spectrum_1d : None or [`~numpy.array`, `~numpy.array`]
Template 1D spectrum to convolve with the grism disperser. If
None, assumes trivial spectrum flat in f_lambda flux densities.
Otherwise, the template is taken to be
>>> wavelength, flux = spectrum_1d
is_cgs : bool
Flux units of `spectrum_1d[1]` are cgs f_lambda flux densities,
rather than normalized in the detection band.
compute_size : bool
Ignore `x`, `y`, and `size` and compute the extent of the
segmentation polygon directly using
`~grizli.utils_numba.disperse.compute_segmentation_limits`.
max_size : int or None
Enforce a maximum size of the cutout when using `compute_size`.
store : bool
If True, then store the computed beams in the OrderedDict
`self.object_dispersers[id]`.
If many objects are computed, this can be memory intensive. To
save memory, set to False and then the function just stores the
input template spectrum (`spectrum_1d`) and the beams will have
to be recomputed if necessary.
in_place : bool
If True, add the computed spectral orders into `self.model`.
Otherwise, make a clean array with only the orders of the given
object.
get_beams : list or None
Spectral orders to retrieve with names as defined in the
configuration files, e.g., ['A'] generally for the +1st order of
HST grisms. If `None`, then get all orders listed in the
`beams` attribute of the `~grizli.grismconf.aXeConf`
configuration object.
psf_params : list
Optional parameters for generating an `~grizli.utils.EffectivePSF`
object for the spatial morphology.
Returns
-------
output : bool or `numpy.array`
If `in_place` is True, return status of True if everything goes
OK. The computed spectral orders are stored in place in
`self.model`.
Returns False if the specified `id` is not found in the
segmentation array independent of `in_place`.
If `in_place` is False, return a full array including the model
for the single object.
"""
from .utils_numba import disperse
if id in self.object_dispersers:
object_in_model = True
beams = self.object_dispersers[id]
out = self.object_dispersers[id]
# Handle pre 0.3.0-7 formats
if len(out) == 3:
old_cgs, old_spectrum_1d, beams = out
else:
old_cgs, old_spectrum_1d = out
beams = None
else:
object_in_model = False
beams = None
if self.direct.data['REF'] is None:
ext = 'SCI'
else:
ext = 'REF'
# set up the beams to extract
if get_beams is None:
beam_names = self.conf.beams
else:
beam_names = get_beams
# Did we initialize the PSF model this call?
INIT_PSF_NOW = False
# Do we need to compute the dispersed beams?
if beams is None:
# Use catalog
xcat = ycat = None
if self.catalog is not None:
ix = self.catalog['id'] == id
if ix.sum() == 0:
if verbose:
print(f'ID {id} not found in segmentation image')
return False
if hasattr(self.catalog['x_flt'][ix][0], 'unit'):
xcat = self.catalog['x_flt'][ix][0].value - 1
ycat = self.catalog['y_flt'][ix][0].value - 1
else:
xcat = self.catalog['x_flt'][ix][0] - 1
ycat = self.catalog['y_flt'][ix][0] - 1
# print '!!! X, Y: ', xcat, ycat, self.direct.origin, size
# use x, y if defined
if x is not None:
xcat = x
if y is not None:
ycat = y
if (compute_size) | (x is None) | (y is None) | (size is None):
# Get the array indices of the segmentation region
out = disperse.compute_segmentation_limits(self.seg, id,
self.direct.data[ext],
self.direct.sh)
ymin, ymax, y, xmin, xmax, x, area, segm_flux = out
if (area == 0) | ~np.isfinite(x) | ~np.isfinite(y):
if verbose:
print('ID {0:d} not found in segmentation image'.format(id))
return False
# Object won't disperse spectrum onto the grism image
if ((ymax < self.pad[0]-5) |
(ymin > self.direct.sh[0]-self.pad[0]+5) |
(ymin == 0) |
(ymax == self.direct.sh[0]) |
(xmin == 0) |
(xmax == self.direct.sh[1])):
return True
if compute_size:
try:
size = int(np.ceil(np.max([x-xmin, xmax-x,
y-ymin, ymax-y])))
except ValueError:
return False
size += 4
# Enforce minimum size
# size = np.maximum(size, 16)
size = np.maximum(size, min_size)
# To do: enforce a larger minimum cutout size for grisms
# that need it, e.g., UVIS/G280L
# maximum size
if max_size is not None:
size = np.min([size, max_size])
# Avoid problems at the array edges
size = np.min([size, int(x)-2, int(y)-2])
if (size < 4):
return True
# Thumbnails
# print '!! X, Y: ', x, y, self.direct.origin, size
if xcat is not None:
xc, yc = int(np.round(xcat))+1, int(np.round(ycat))+1
xcenter = (xcat-(xc-1))
ycenter = (ycat-(yc-1))
else:
xc, yc = int(np.round(x))+1, int(np.round(y))+1
xcenter = (x-(xc-1))
ycenter = (y-(yc-1))
origin = [yc-size + self.direct.origin[0],
xc-size + self.direct.origin[1]]
thumb = self.direct.data[ext][yc-size:yc+size, xc-size:xc+size]
seg_thumb = self.seg[yc-size:yc+size, xc-size:xc+size]
# Test that the id is actually in the thumbnail
test = disperse.compute_segmentation_limits(seg_thumb, id, thumb,
np.array(thumb.shape))
if test[-2] == 0:
if verbose:
print(f'ID {id} not found in segmentation image')
return False
# # Get precomputed dispersers
# beams, old_spectrum_1d, old_cgs = None, None, False
# if object_in_model:
# out = self.object_dispersers[id]
#
# # Handle pre 0.3.0-7 formats
# if len(out) == 3:
# old_cgs, old_spectrum_1d, old_beams = out
# else:
# old_cgs, old_spectrum_1d = out
# old_beams = None
#
# # Pull out just the requested beams
# if old_beams is not None:
# beams = OrderedDict()
# for b in beam_names:
# beams[b] = old_beams[b]
#
# if beams is None:
# Compute spectral orders ("beams")
beams = OrderedDict()
for b in beam_names:
# Only compute order if bright enough
if mag > self.conf.conf_dict['MMAG_EXTRACT_{0}'.format(b)]:
continue
# if 1:
# beam = GrismDisperser(id=id,
# direct=thumb,
# segmentation=seg_thumb,
# xcenter=xcenter,
# ycenter=ycenter,
# origin=origin,
# pad=self.pad,
# grow=self.grism.grow,
# beam=b,
# conf=self.conf,
# fwcpos=self.grism.fwcpos,
# MW_EBV=self.grism.MW_EBV)
try:
beam = GrismDisperser(id=id,
direct=thumb,
segmentation=seg_thumb,
xcenter=xcenter,
ycenter=ycenter,
origin=origin,
pad=self.pad,
grow=self.grism.grow,
beam=b,
conf=self.conf,
fwcpos=self.grism.fwcpos,
MW_EBV=self.grism.MW_EBV)
except:
utils.log_exception(utils.LOGFILE, traceback)
continue
# Set PSF model if necessary
if psf_params is not None:
store = True
INIT_PSF_NOW = True
if self.direct.ref_filter is None:
psf_filter = self.direct.filter
else:
psf_filter = self.direct.ref_filter
beam.x_init_epsf(flat_sensitivity=False,
psf_params=psf_params,
psf_filter=psf_filter, yoff=0.)
beams[b] = beam
# Compute old model
if object_in_model:
for b in beams:
beam = beams[b]
if hasattr(beam, 'psf') & (not INIT_PSF_NOW):
store = True
beam.compute_model_psf(spectrum_1d=old_spectrum_1d,
is_cgs=old_cgs)
else:
beam.compute_model(spectrum_1d=old_spectrum_1d,
is_cgs=old_cgs)
if get_beams:
out_beams = OrderedDict()
for b in beam_names:
out_beams[b] = beams[b]
return out_beams
if in_place:
# Update the internal model attribute
output = self.model
if store:
# Save the computed beams
self.object_dispersers[id] = is_cgs, spectrum_1d, beams
else:
# Just save the model spectrum (or empty spectrum)
self.object_dispersers[id] = is_cgs, spectrum_1d, None
else:
# Create a fresh array
output = np.zeros_like(self.model)
# if in_place:
# ### Update the internal model attribute
# output = self.model
# else:
# ### Create a fresh array
# output = np.zeros_like(self.model)
# Set PSF model if necessary
if psf_params is not None:
if self.direct.ref_filter is None:
psf_filter = self.direct.filter
else:
psf_filter = self.direct.ref_filter
# Loop through orders and add to the full model array, in-place or
# a separate image
for b in beams:
beam = beams[b]
# Subtract previously-added model
if object_in_model & in_place:
beam.add_to_full_image(-beam.model, output)
# Update PSF params
# if psf_params is not None:
# skip_init_psf = False
# if hasattr(beam, 'psf_params'):
# skip_init_psf |= np.prod(np.isclose(beam.psf_params, psf_params)) > 0
#
# if not skip_init_psf:
# beam.x_init_epsf(flat_sensitivity=False, psf_params=psf_params, psf_filter=psf_filter, yoff=0.06)
# Compute model
if hasattr(beam, 'psf'):
beam.compute_model_psf(spectrum_1d=spectrum_1d, is_cgs=is_cgs)
else:
beam.compute_model(spectrum_1d=spectrum_1d, is_cgs=is_cgs)
# Add in new model
beam.add_to_full_image(beam.model, output)
if in_place:
return True
else:
return beams, output
[docs] def compute_full_model(self, ids=None, mags=None, mag_limit=22, store=True, verbose=False, size=10, min_size=26, compute_size=True):
"""Compute flat-spectrum model for multiple objects.
Parameters
----------
ids : None, list, or `~numpy.array`
id numbers to compute in the model. If None then take all ids
from unique values in `self.seg`.
mags : None, float, or list / `~numpy.array`
magnitudes corresponding to list if `ids`. If None, then compute
magnitudes based on the flux in segmentation regions and
zeropoints determined from PHOTFLAM and PHOTPLAM.
size, compute_size : int, bool
Sizes of individual cutouts, see
`~grizli.model.GrismFLT.compute_model_orders`.
Returns
-------
Updated model stored in `self.model` attribute.
"""
try:
from tqdm import tqdm
has_tqdm = True
except:
has_tqdm = False
print('(`pip install tqdm` for a better verbose iterator)')
from .utils_numba import disperse
if ids is None:
ids = np.unique(self.seg)[1:]
# If `mags` array not specified, compute magnitudes within
# segmentation regions.
if mags is None:
if verbose:
print('Compute IDs/mags')
mags = np.zeros(len(ids))
for i, id in enumerate(ids):
out = disperse.compute_segmentation_limits(self.seg, id,
self.direct.data[self.direct.thumb_extension],
self.direct.sh)
ymin, ymax, y, xmin, xmax, x, area, segm_flux = out
mags[i] = self.direct.ABZP - 2.5*np.log10(segm_flux)
ix = mags < mag_limit
ids = ids[ix]
mags = mags[ix]
else:
if np.isscalar(mags):
mags = [mags for i in range(len(ids))]
else:
if len(ids) != len(mags):
raise ValueError('`ids` and `mags` lists different sizes')
# Now compute the full model
if verbose & has_tqdm:
iterator = tqdm(zip(ids, mags))
else:
iterator = zip(ids, mags)
for id_i, mag_i in iterator:
self.compute_model_orders(id=id_i,
compute_size=compute_size,
mag=mag_i,
size=size,
in_place=True,
store=store,
min_size=min_size,
)
[docs] def smooth_mask(self, gaussian_width=4, threshold=2.5):
"""Compute a mask where smoothed residuals greater than some value
Perhaps useful for flagging contaminated pixels that aren't in the
model, such as high orders dispersed from objects that fall off of the
direct image, but this hasn't yet been extensively tested.
Parameters
----------
gaussian_width : float
Width of the Gaussian filter used with
`~scipy.ndimage.gaussian_filter`.
threshold : float
Threshold, in sigma, above which to flag residuals.
Returns
-------
Nothing, but pixels are masked in `self.grism.data['SCI']`.
"""
import scipy.ndimage as nd
mask = self.grism['SCI'] != 0
resid = (self.grism['SCI'] - self.model)*mask
sm = nd.gaussian_filter(np.abs(resid), gaussian_width)
resid_mask = (np.abs(sm) > threshold*self.grism['ERR'])
self.grism.data['SCI'][resid_mask] = 0
[docs] def blot_catalog(self, input_catalog, columns=['id', 'ra', 'dec'],
sextractor=False, ds9=None):
"""Compute detector-frame coordinates of sky positions in a catalog.
Parameters
----------
input_catalog : `~astropy.table.Table`
Full catalog with sky coordinates. Can be SExtractor or other.
columns : [str,str,str]
List of columns that specify the object id, R.A. and Decl. For
catalogs created with SExtractor this might be
['NUMBER', 'X_WORLD', 'Y_WORLD'].
Detector coordinates will be computed with
`self.direct.wcs.all_world2pix` with `origin=1`.
ds9 : `~grizli.ds9.DS9`, optional
If provided, load circular regions at the derived detector
coordinates.
Returns
-------
catalog : `~astropy.table.Table`
New catalog with columns 'x_flt' and 'y_flt' of the detector
coordinates. Also will copy the `columns` names to columns with
names 'id','ra', and 'dec' if necessary, e.g., for SExtractor
catalogs.
"""
from astropy.table import Column
if sextractor:
columns = ['NUMBER', 'X_WORLD', 'Y_WORLD']
# Detector coordinates. N.B.: 1 indexed!
xy = self.direct.wcs.all_world2pix(input_catalog[columns[1]],
input_catalog[columns[2]], 1,
tolerance=-4,
quiet=True)
# Objects with positions within the image
sh = self.direct.sh
keep = ((xy[0] > 0) & (xy[0] < sh[1]) &
(xy[1] > (self.pad[0]-5)) & (xy[1] < (sh[0]-self.pad[0]+5)))
catalog = input_catalog[keep]
# Remove columns if they exist
for col in ['x_flt', 'y_flt']:
if col in catalog.colnames:
catalog.remove_column(col)
# Columns with detector coordinates
catalog.add_column(Column(name='x_flt', data=xy[0][keep]))
catalog.add_column(Column(name='y_flt', data=xy[1][keep]))
# Copy standardized column names if necessary
if ('id' not in catalog.colnames):
catalog.add_column(Column(name='id', data=catalog[columns[0]]))
if ('ra' not in catalog.colnames):
catalog.add_column(Column(name='ra', data=catalog[columns[1]]))
if ('dec' not in catalog.colnames):
catalog.add_column(Column(name='dec', data=catalog[columns[2]]))
# Show positions in ds9
if ds9:
for i in range(len(catalog)):
x_flt, y_flt = catalog['x_flt'][i], catalog['y_flt'][i]
reg = 'circle {0:f} {1:f} 5\n'.format(x_flt, y_flt)
ds9.set('regions', reg)
return catalog
[docs] def photutils_detection(self, use_seg=False, data_ext='SCI',
detect_thresh=2., grow_seg=5, gauss_fwhm=2.,
verbose=True, save_detection=False, ZP=None):
"""Use photutils to detect objects and make segmentation map
Parameters
----------
detect_thresh : float
Detection threshold, in sigma
grow_seg : int
Number of pixels to grow around the perimeter of detected objects
witha maximum filter
gauss_fwhm : float
FWHM of Gaussian convolution kernel that smoothes the detection
image.
verbose : bool
Print logging information to the terminal
save_detection : bool
Save the detection images and catalogs
ZP : float or None
AB magnitude zeropoint of the science array. If `None` then, try
to compute based on PHOTFLAM and PHOTPLAM values and use zero if
that fails.
Returns
-------
status : bool
True if completed successfully. False if `data_ext=='REF'` but
no reference image found.
Stores an astropy.table.Table object to `self.catalog` and a
segmentation array to `self.seg`.
"""
if ZP is None:
if ((self.direct.filter in photflam_list.keys()) &
(self.direct.filter in photplam_list.keys())):
# ABMAG_ZEROPOINT from
# http://www.stsci.edu/hst/wfc3/phot_zp_lbn
ZP = (-2.5*np.log10(photflam_list[self.direct.filter]) -
21.10 - 5*np.log10(photplam_list[self.direct.filter]) +
18.6921)
else:
ZP = 0.
if use_seg:
seg = self.seg
else:
seg = None
if self.direct.data['ERR'].max() != 0.:
err = self.direct.data['ERR']/self.direct.photflam
else:
err = None
if (data_ext == 'REF'):
if (self.direct.data['REF'] is not None):
err = None
else:
print('No reference data found for `self.direct.data[\'REF\']`')
return False
go_detect = utils.detect_with_photutils
cat, seg = go_detect(self.direct.data[data_ext]/self.direct.photflam,
err=err, dq=self.direct.data['DQ'], seg=seg,
detect_thresh=detect_thresh, npixels=8,
grow_seg=grow_seg, gauss_fwhm=gauss_fwhm,
gsize=3, wcs=self.direct.wcs,
save_detection=save_detection,
root=self.direct_file.split('.fits')[0],
background=None, gain=None, AB_zeropoint=ZP,
overwrite=True, verbose=verbose)
self.catalog = cat
self.catalog_file = '<photutils>'
self.seg = seg
return True
[docs] def load_photutils_detection(self, seg_file=None, seg_cat=None,
catalog_format='ascii.commented_header'):
"""
Load segmentation image and catalog, either from photutils
or SExtractor.
If SExtractor, use `catalog_format='ascii.sextractor'`.
"""
root = self.direct_file.split('.fits')[0]
if seg_file is None:
seg_file = root + '.detect_seg.fits'
if not os.path.exists(seg_file):
print('Segmentation image {0} not found'.format(seg_file))
return False
with pyfits.open(seg_file) as seg_im:
self.seg = seg_im[0].data.astype(np.float32)
if seg_cat is None:
seg_cat = root + '.detect.cat'
if not os.path.exists(seg_cat):
print('Segmentation catalog {0} not found'.format(seg_cat))
return False
self.catalog = Table.read(seg_cat, format=catalog_format)
self.catalog_file = seg_cat
[docs] def save_model(self, overwrite=True, verbose=True):
"""Save model properties to FITS file
"""
try:
import cPickle as pickle
except:
# Python 3
import pickle
root = self.grism_file.split('_flt.fits')[0].split('_rate.fits')[0]
root = root.split('_elec.fits')[0]
if isinstance(self.pad, int):
_pad = (self.pad, self.pad)
else:
_pad = self.pad
h = pyfits.Header()
h['GFILE'] = (self.grism_file, 'Grism exposure name')
h['GFILTER'] = (self.grism.filter, 'Grism spectral element')
h['INSTRUME'] = (self.grism.instrument, 'Instrument of grism file')
h['PADX'] = (_pad[1], 'Image padding used axis1')
h['PADY'] = (_pad[0], 'Image padding used axis2')
h['DFILE'] = (self.direct_file, 'Direct exposure name')
h['DFILTER'] = (self.direct.filter, 'Grism spectral element')
h['REF_FILE'] = (self.ref_file, 'Reference image')
h['SEG_FILE'] = (self.seg_file, 'Segmentation image')
h['CONFFILE'] = (self.conf_file, 'Configuration file')
h['DISP_PA'] = (self.dispersion_PA, 'Dispersion position angle')
h0 = pyfits.PrimaryHDU(header=h)
model = pyfits.ImageHDU(data=self.model, header=self.grism.header,
name='MODEL')
seg = pyfits.ImageHDU(data=self.seg, header=self.grism.header,
name='SEG')
hdu = pyfits.HDUList([h0, model, seg])
if 'REF' in self.direct.data:
ref_header = self.grism.header.copy()
ref_header['FILTER'] = self.direct.ref_filter
ref_header['PARENT'] = self.ref_file
ref_header['PHOTFLAM'] = self.direct.ref_photflam
ref_header['PHOTPLAM'] = self.direct.ref_photplam
ref = pyfits.ImageHDU(data=self.direct['REF'],
header=ref_header, name='REFERENCE')
hdu.append(ref)
hdu.writeto('{0}_model.fits'.format(root), overwrite=overwrite,
output_verify='fix')
fp = open('{0}_model.pkl'.format(root), 'wb')
pickle.dump(self.object_dispersers, fp)
fp.close()
if verbose:
print('Saved {0}_model.fits and {0}_model.pkl'.format(root))
[docs] def save_full_pickle(self, verbose=True):
"""Save entire `GrismFLT` object to a pickle
"""
try:
import cPickle as pickle
except:
# Python 3
import pickle
root = self.grism_file.split('_flt.fits')[0].split('_cmb.fits')[0]
root = root.split('_flc.fits')[0].split('_rate.fits')[0]
root = root.split('_elec.fits')[0]
if root == self.grism_file:
# unexpected extension, so just insert before '.fits'
root = self.grism_file.split('.fits')[0]
hdu = pyfits.HDUList([pyfits.PrimaryHDU()])
# Remove dummy extensions if REF found
skip_direct_extensions = []
if 'REF' in self.direct.data:
if self.direct.data['REF'] is not None:
skip_direct_extensions = ['SCI','ERR','DQ']
for key in self.direct.data.keys():
if key in skip_direct_extensions:
hdu.append(pyfits.ImageHDU(data=None,
header=self.direct.header,
name='D'+key))
else:
hdu.append(pyfits.ImageHDU(data=self.direct.data[key],
header=self.direct.header,
name='D'+key))
for key in self.grism.data.keys():
hdu.append(pyfits.ImageHDU(data=self.grism.data[key],
header=self.grism.header,
name='G'+key))
hdu.append(pyfits.ImageHDU(data=self.seg,
header=self.grism.header,
name='SEG'))
hdu.append(pyfits.ImageHDU(data=self.model,
header=self.grism.header,
name='MODEL'))
hdu.writeto('{0}.{1:02d}.GrismFLT.fits'.format(root, self.grism.sci_extn),
overwrite=True,
output_verify='fix')
# zero out large data objects
self.direct.data = self.grism.data = self.seg = self.model = None
# Don't store conf in pickle
if hasattr(self, 'conf'):
delattr(self, 'conf')
fp = open('{0}.{1:02d}.GrismFLT.pkl'.format(root,
self.grism.sci_extn), 'wb')
pickle.dump(self, fp)
fp.close()
self.save_wcs(overwrite=True, verbose=False)
# Reload conf
self.conf = grismconf.load_grism_config(self.conf_file)
[docs] def save_wcs(self, overwrite=True, verbose=True):
"""TBD
"""
if self.direct.parent_file == self.grism.parent_file:
base_list = [self.grism]
else:
base_list = [self.direct, self.grism]
for base in base_list:
hwcs = base.wcs.to_fits(relax=True)
hwcs[0].header['PADX'] = base.pad[1]
hwcs[0].header['PADY'] = base.pad[0]
if 'CCDCHIP' in base.header:
ext = {1: 2, 2: 1}[base.header['CCDCHIP']]
else:
ext = base.header['EXTVER']
wcsfile = base.parent_file.replace('.fits', f'.{ext:02d}.wcs.fits')
try:
hwcs.writeto(wcsfile, overwrite=overwrite)
except:
hwcs.writeto(wcsfile, clobber=overwrite)
if verbose:
print(wcsfile)
[docs] def load_from_fits(self, save_file):
"""Load saved data from a FITS file
Parameters
----------
save_file : str
Filename of the saved output
Returns
-------
True if completed successfully
"""
fits = pyfits.open(save_file)
self.seg = fits['SEG'].data*1
self.model = fits['MODEL'].data*1
self.direct.data = OrderedDict()
self.grism.data = OrderedDict()
for ext in range(1, len(fits)):
key = fits[ext].header['EXTNAME'][1:]
if fits[ext].header['EXTNAME'].startswith('D'):
if fits[ext].data is None:
self.direct.data[key] = None
else:
self.direct.data[key] = fits[ext].data*1
elif fits[ext].header['EXTNAME'].startswith('G'):
if fits[ext].data is None:
self.grism.data[key] = None
else:
self.grism.data[key] = fits[ext].data*1
else:
pass
fits.close()
del(fits)
return True
[docs] def apply_POM(self, warn_if_too_small=True, verbose=True):
"""
Apply pickoff mask to segmentation map to control sources that are dispersed onto the detector
"""
if not self.grism.instrument.startswith('NIRCAM'):
print('POM only defined for NIRCam')
return True
pom_path = os.path.join(GRIZLI_PATH,
f'CONF/GRISM_NIRCAM/V*/NIRCAM_LW_POM_Mod{self.grism.module}.fits')
pom_files = glob.glob(pom_path)
if len(pom_files) == 0:
print(f'Couldn\'t find POM reference files {pom_path}')
return False
pom_files.sort()
pom_file = pom_files[-1]
if verbose:
print(f'NIRCam: apply POM geometry from {pom_file}')
pom = pyfits.open(pom_file)[-1]
pomh = pom.header
if self.grism.pupil.lower() == 'grismc':
_warn = self.pad[0] < 790
_padix = 0
elif self.grism.pupil.lower() == 'grismr':
_warn = self.pad[1] < 790
_padix = 1
else:
_warn = False
if _warn & warn_if_too_small:
print(f'Warning: `pad[{_padix}]` should be > 790 for '
f'NIRCam/{self.grism.pupil} to catch '
'all out-of-field sources within the POM coverage.')
# Slice geometry
a_origin = np.array([-self.pad[0], -self.pad[1]])
a_shape = np.array(self.grism.sh)
b_origin = np.array([-pomh['NOMYSTRT'], -pomh['NOMXSTRT']])
b_shape = np.array(pom.data.shape)
self_sl, pom_sl = utils.get_common_slices(a_origin, a_shape,
b_origin, b_shape)
pom_data = self.seg*0
pom_data[self_sl] += pom.data[pom_sl]
self.pom_data = pom_data
self.seg *= (pom_data > 0)
return True
[docs] def mask_mosaic_edges(self, sky_poly=None, verbose=True, force=False, err_scale=10, dq_mask=False, dq_value=1024, resid_sn=7):
"""
Mask edges of exposures that might not have modeled spectra
"""
from regions import Regions
import scipy.ndimage as nd
if (self.has_edge_mask) & (force is False):
return True
if sky_poly is None:
return True
xy_image = self.grism.wcs.all_world2pix(np.array(sky_poly.boundary.xy).T, 0)
# Calculate edge for mask
#xedge = 100
x0 = 0
y0 = (self.grism.sh[0] - 2*self.pad[0])/2
dx = np.arange(500)
tr_y, tr_lam = self.conf.get_beam_trace(x0, y0, dx=dx, beam='A')
tr_sens = np.interp(tr_lam, self.conf.sens['A']['WAVELENGTH'],
self.conf.sens['A']['SENSITIVITY'],
left=0, right=0)
xedge = dx[tr_sens > tr_sens.max()*0.05].max()
xy_image[:, 0] += xedge
xy_str = 'image;polygon('+','.join(['{0:.1f}'.format(p + 1) for p in xy_image.flatten()])+')'
reg = Regions.parse(xy_str, format='ds9')[0]
mask = reg.to_mask().to_image(shape=self.grism.sh).astype(bool)
# Only mask large residuals
if resid_sn > 0:
resid_mask = (self.grism['SCI'] - self.model) > resid_sn*self.grism['ERR']
resid_mask = nd.binary_dilation(resid_mask, iterations=3)
mask &= resid_mask
if dq_mask:
self.grism.data['DQ'] |= dq_value*mask
if verbose:
print('# mask mosaic edges: {0} ({1}, {2} pix) DQ={3:.0f}'.format(self.grism.parent_file, self.grism.filter, xedge, dq_value))
else:
self.grism.data['ERR'][mask] *= err_scale
if verbose:
print('# mask mosaic edges: {0} ({1}, {2} pix) err_scale={3:.1f}'.format(self.grism.parent_file, self.grism.filter, xedge, err_scale))
self.has_edge_mask = True
[docs] def get_trace_region_from_sky(self, ra, dec, width=2):
"""
Make a region file for the trace in pixel coordinates given sky position
TBD
"""
return None
[docs] def old_make_edge_mask(self, scale=3, force=False):
"""Make a mask for the edge of the grism FoV that isn't covered by the direct image
Parameters
----------
scale : float
Scale factor to multiply to the mask before it's applied to the
`self.grism.data['ERR']` array.
force : bool
Force apply the mask even if `self.has_edge_mask` is set
indicating that the function has already been run.
Returns
-------
Nothing, updates `self.grism.data['ERR']` in place.
Sets `self.has_edge_mask = True`.
"""
import scipy.ndimage as nd
if (self.has_edge_mask) & (force is False):
return True
kern = (np.arange(self.conf.conf_dict['BEAMA'][1]) > self.conf.conf_dict['BEAMA'][0])*1.
kern /= kern.sum()
if self.direct['REF'] is not None:
mask = self.direct['REF'] == 0
else:
mask = self.direct['SCI'] == 0
full_mask = nd.convolve(mask*1., kern.reshape((1, -1)),
origin=(0, -kern.size//2+20))
self.grism.data['ERR'] *= np.exp(full_mask*scale)
self.has_edge_mask = True
[docs]class BeamCutout(object):
def __init__(self, flt=None, beam=None, conf=None,
get_slice_header=True, fits_file=None, scale=1.,
contam_sn_mask=[10, 3], min_mask=0.01, min_sens=0.08,
mask_resid=True, isJWST=False, restore_medfilt=False):
"""Cutout spectral object from the full frame.
Parameters
----------
flt : `GrismFLT`
Parent FLT frame.
beam : `GrismDisperser`
Object and spectral order to consider
conf : `.grismconf.aXeConf`
Pre-computed configuration file. If not specified will regenerate
based on header parameters, which might be necessary for
multiprocessing parallelization and pickling.
get_slice_header : bool
TBD
fits_file : None or str
Optional FITS file containing the beam information, rather than
reading directly from a `GrismFLT` object with the `flt` and
`beam` paremters. Load with `load_fits`.
contam_sn_mask : TBD
min_mask : float
Minimum factor relative to the maximum pixel value of the flat
f-lambda model where the 2D cutout data are considered good.
min_sens : float
Minimum sensitivity relative to the maximum for a given grism
above which pixels are included in the fit.
Attributes
----------
grism, direct : `ImageData` (sliced)
Cutouts of the grism and direct images.
beam : `GrismDisperser`
High-level tools for computing dispersed models of the object
mask : array-like (bool)
Basic mask where `grism` DQ > 0 | ERR == 0 | SCI == 0.
fit_mask, DoF : array-like, int
Additional mask, DoF is `fit_mask.sum()` representing the
effective degrees of freedom for chi-squared.
ivar : array-like
Inverse variance array, taken from `grism` 1/ERR^2
model, modelf : array-like
2D and flattened versions of the object model array
contam : array-like
Contamination model
scif : array_like
Flattened version of `grism['SCI'] - contam`.
flat_flam : array-like
Flattened version of the flat-flambda object model
poly_order : int
Order of the polynomial model
"""
self.background = 0.
self.module = None
if fits_file is not None:
self.load_fits(fits_file, conf, restore_medfilt=restore_medfilt)
else:
self.init_from_input(flt, beam, conf, get_slice_header)
self.beam.scale = scale
self._parse_params = {'contam_sn_mask':contam_sn_mask,
'min_mask':min_mask,
'min_sens':min_sens,
'mask_resid':mask_resid}
# self.contam_sn_mask = contam_sn_mask
# self.min_mask = min_mask
# self.min_sens = min_sens
# self.mask_resid = mask_resid
self._parse_from_data(isJWST=isJWST, **self._parse_params)
def _parse_from_data(self, contam_sn_mask=[10, 3], min_mask=0.01,
seg_ids=None, min_sens=0.08, mask_resid=True, isJWST=False):
"""
See parameter description for `~grizli.model.BeamCutout`.
"""
# bad pixels or problems with uncertainties
self.mask = ((self.grism.data['DQ'] > 0) |
(self.grism.data['ERR'] == 0) |
(self.grism.data['SCI'] == 0))
self.var = self.grism.data['ERR']**2
self.var[self.mask] = 1.e30
self.ivar = 1/self.var
self.ivar[self.mask] = 0
self.thumbs = {}
#self.compute_model = self.beam.compute_model
#self.model = self.beam.model
self.modelf = self.beam.modelf # .flatten()
self.model = self.beam.modelf.reshape(self.beam.sh_beam)
# Attributes
self.size = self.modelf.size
self.wave = self.beam.lam
self.sh = self.beam.sh_beam
# Initialize for fits
if seg_ids is None:
self.flat_flam = self.compute_model(in_place=False, is_cgs=True)
else:
for i, sid in enumerate(seg_ids):
flat_i = self.compute_model(id=sid, in_place=False,
is_cgs=True)
if i == 0:
self.flat_flam = flat_i
else:
self.flat_flam += flat_i
# OK data where the 2D model has non-zero flux
self.fit_mask = (~self.mask.flatten()) & (self.ivar.flatten() != 0)
try:
self.fit_mask &= (self.flat_flam > min_mask*self.flat_flam.max())
except ValueError:
print('config file: ', self.beam.conf.conf_file)
utils.log_exception(utils.LOGFILE, traceback)
#self.fit_mask &= (self.flat_flam > 3*self.contam.flatten())
# Apply minimum sensitivity mask
self.sens_mask = 1.
if min_sens > 0:
flux_min_sens = (self.beam.sensitivity <
min_sens*self.beam.sensitivity.max())*1.
if flux_min_sens.sum() > 0:
test_spec = [self.beam.lam, flux_min_sens]
if seg_ids is None:
flat_sens = self.compute_model(in_place=False,
is_cgs=True,
spectrum_1d=test_spec)
else:
for i, sid in enumerate(seg_ids):
f_i = self.compute_model(id=sid, in_place=False,
is_cgs=True, spectrum_1d=test_spec)
if i == 0:
flat_sens = f_i
else:
flat_sens += f_i
# self.sens_mask = flat_sens == 0
# Make mask along columns
is_masked = (flat_sens.reshape(self.sh) > 0).sum(axis=0)
self.sens_mask = (np.dot(np.ones((self.sh[0], 1)), is_masked[None, :]) == 0).flatten()
self.fit_mask &= self.sens_mask
# Flat versions of sci/ivar arrays
self.scif = (self.grism.data['SCI'] - self.contam).flatten()
self.ivarf = self.ivar.flatten()
self.wavef = np.dot(np.ones((self.sh[0], 1)), self.wave[None, :]).flatten()
# Mask large residuals where throughput is low
if mask_resid:
resid = np.abs(self.scif - self.flat_flam)*np.sqrt(self.ivarf)
bad_resid = (self.flat_flam < 0.05*self.flat_flam.max())
bad_resid &= (resid > 5)
self.bad_resid = bad_resid
self.fit_mask *= ~bad_resid
else:
self.bad_resid = np.zeros_like(self.fit_mask)
# Mask very contaminated
contam_mask = ((self.contam*np.sqrt(self.ivar) > contam_sn_mask[0]) &
(self.model*np.sqrt(self.ivar) < contam_sn_mask[1]))
#self.fit_mask *= ~contam_mask.flatten()
self.contam_mask = ~nd.maximum_filter(contam_mask, size=5).flatten()
self.poly_order = None
# self.init_poly_coeffs(poly_order=1)
[docs] def load_fits(self, file, conf=None, direct_extn=1, grism_extn=2, restore_medfilt=False):
"""Initialize from FITS file
Parameters
----------
file : str
FITS file to read (as output from `write_fits`).
Returns
-------
Loads attributes to `self`.
"""
if isinstance(file, str):
hdu = pyfits.open(file)
file_is_open = True
else:
file_is_open = False
hdu = file
self.direct = ImageData(hdulist=hdu, sci_extn=direct_extn,
restore_medfilt=restore_medfilt)
self.grism = ImageData(hdulist=hdu, sci_extn=grism_extn,
restore_medfilt=restore_medfilt)
self.contam = hdu['CONTAM'].data*1
try:
self.modelf = hdu['MODEL'].data.flatten().astype(np.float32)*1
except:
self.modelf = self.grism['SCI'].flatten().astype(np.float32)*0.
if ('REF', 1) in hdu:
direct = hdu['REF', 1].data*1
else:
direct = hdu['SCI', 1].data*1
h0 = hdu[0].header
# if 'DFILTER' in self.grism.header:
# direct_filter = self.grism.header['DFILTER']
# else:
# direct_filter = self.direct.filter
# #
if 'DFILTER' in self.grism.header:
direct_filter = self.grism.header['DFILTER']
if self.grism.instrument in ['NIRCAM', 'NIRISS']:
direct_filter = self.grism.pupil
else:
direct_filter = self.direct.filter
if conf is None:
conf_args = dict(instrume=self.grism.instrument,
filter=direct_filter,
grism=self.grism.filter,
module=self.grism.module,
chip=self.grism.ccdchip)
if 'CONFFILE' in self.grism.header:
self.conf_file = self.grism.header['CONFFILE']
else:
self.conf_file = grismconf.get_config_filename(**conf_args)
self.grism.header['CONFFILE'] = self.conf_file
conf = grismconf.load_grism_config(self.conf_file)
if 'GROW' in self.grism.header:
grow = self.grism.header['GROW']
else:
grow = 1
if 'MW_EBV' in h0:
self.grism.MW_EBV = h0['MW_EBV']
else:
self.grism.MW_EBV = 0
self.grism.fwcpos = h0['FWCPOS']
if (self.grism.fwcpos == 0) | (self.grism.fwcpos == ''):
self.grism.fwcpos = None
if 'TYOFFSET' in h0:
yoffset = h0['TYOFFSET']
else:
yoffset = 0.
if 'TXOFFSET' in h0:
xoffset = h0['TXOFFSET']
else:
xoffset = None
if ('PADX' in h0) & ('PADY' in h0):
_pad = [h0['PADY'], h0['PADX']]
elif ('PAD' in h0):
_pad = [h0['PAD'], h0['PAD']]
self.beam = GrismDisperser(id=h0['ID'], direct=direct,
segmentation=hdu['SEG'].data*1,
origin=self.direct.origin,
pad=_pad,
grow=grow, beam=h0['BEAM'],
xcenter=h0['XCENTER'],
ycenter=h0['YCENTER'],
conf=conf, fwcpos=self.grism.fwcpos,
MW_EBV=self.grism.MW_EBV,
yoffset=yoffset, xoffset=xoffset)
self.grism.parent_file = h0['GPARENT']
self.direct.parent_file = h0['DPARENT']
self.id = h0['ID']
self.modelf = self.beam.modelf
# Cleanup
if file_is_open:
hdu.close()
@property
def trace_table(self):
"""
Table of trace parameters. Trace is unit-indexed.
"""
dtype = np.float32
tab = utils.GTable()
tab.meta['CONFFILE'] = os.path.basename(self.beam.conf.conf_file)
tab['wavelength'] = np.asarray(self.beam.lam*u.Angstrom,dtype=dtype)
tab['trace'] = np.asarray(self.beam.ytrace + self.beam.sh_beam[0]/2 - self.beam.ycenter,dtype=dtype)
sens_units = u.erg/u.second/u.cm**2/u.Angstrom/(u.electron/u.second)
tab['sensitivity'] = np.asarray(self.beam.sensitivity*sens_units,dtype=dtype)
return tab
[docs] def write_fits(self, root='beam_', overwrite=True, strip=False, include_model=True, get_hdu=False, get_trace_table=True):
"""Write attributes and data to FITS file
Parameters
----------
root : str
Output filename will be
'{root}_{self.id}.{self.grism.filter}.{self.beam}.fits'
with `self.id` zero-padded with 5 digits.
overwrite : bool
Overwrite existing file.
strip : bool
Strip out extensions that aren't totally necessary for
regenerating the `ImageData` object. That is, strip out the
direct image `SCI`, `ERR`, and `DQ` extensions if `REF` is
defined. Also strip out `MODEL`.
get_hdu : bool
Return `~astropy.io.fits.HDUList` rather than writing a file.
Returns
-------
hdu : `~astropy.io.fits.HDUList`
If `get_hdu` is True
outfile : str
If `get_hdu` is False, return the output filename.
"""
h0 = pyfits.Header()
h0['ID'] = self.beam.id, 'Object ID'
h0['PADX'] = self.beam.pad[1], 'Padding of input image axis1'
h0['PADY'] = self.beam.pad[0], 'Padding of input image axis2'
h0['BEAM'] = self.beam.beam, 'Grism order ("beam")'
h0['XCENTER'] = (self.beam.xcenter,
'Offset of centroid wrt thumb center')
h0['YCENTER'] = (self.beam.ycenter,
'Offset of centroid wrt thumb center')
if hasattr(self.beam, 'yoffset'):
h0['TYOFFSET'] = (self.beam.yoffset,
'Cross dispersion offset of the trace')
if hasattr(self.beam, 'xoffset'):
h0['TXOFFSET'] = (self.beam.xoffset,
'Dispersion offset of the trace')
h0['GPARENT'] = (self.grism.parent_file,
'Parent grism file')
h0['DPARENT'] = (self.direct.parent_file,
'Parent direct file')
h0['FWCPOS'] = (self.grism.fwcpos,
'Filter wheel position (NIRISS)')
h0['MW_EBV'] = (self.grism.MW_EBV,
'Milky Way exctinction E(B-V)')
hdu = pyfits.HDUList([pyfits.PrimaryHDU(header=h0)])
hdu.extend(self.direct.get_HDUList(extver=1))
hdu.append(pyfits.ImageHDU(data=np.asarray(self.beam.seg,dtype=np.int32),
header=hdu[-1].header, name='SEG'))
# 2D grism spectra
grism_hdu = self.grism.get_HDUList(extver=2)
#######
# 2D Spectroscopic WCS
hdu2d, wcs2d = self.get_2d_wcs()
# Get available 'WCSNAME'+key
for key in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ':
if 'WCSNAME{0}'.format(key) not in self.grism.header:
break
else:
wcsname = self.grism.header['WCSNAME{0}'.format(key)]
if wcsname == 'BeamLinear2D':
break
# h2d = wcs2d.to_header(key=key)
h2d = utils.to_header(wcs2d, key=key)
for ext in grism_hdu:
for k in h2d:
if k not in ext.header:
ext.header[k] = h2d[k], h2d.comments[k]
####
hdu.extend(grism_hdu)
hdu.append(pyfits.ImageHDU(data=self.contam, header=hdu[-1].header,
name='CONTAM'))
if include_model:
hdu.append(pyfits.ImageHDU(data=np.asarray(self.model,dtype=np.float32),
header=hdu[-1].header, name='MODEL'))
if get_trace_table:
trace_hdu = pyfits.table_to_hdu(self.trace_table)
trace_hdu.header['EXTNAME'] = 'TRACE'
trace_hdu.header['EXTVER'] = 2
hdu.append(trace_hdu)
if strip:
# Blotted reference is attached, don't need individual direct
# arrays.
if self.direct['REF'] is not None:
for ext in [('SCI', 1), ('ERR', 1), ('DQ', 1)]:
if ext in hdu:
ix = hdu.index_of(ext)
p = hdu.pop(ix)
# This can be regenerated
# if strip & 2:
# ix = hdu.index_of('MODEL')
# p = hdu.pop(ix)
# Put Primary keywords in first extension
SKIP_KEYS = ['EXTEND', 'SIMPLE']
for key in h0:
if key not in SKIP_KEYS:
hdu[1].header[key] = (h0[key], h0.comments[key])
hdu['SCI', 2].header[key] = (h0[key], h0.comments[key])
if get_hdu:
return hdu
outfile = '{0}_{1:05d}.{2}.{3}.fits'.format(root, self.beam.id,
self.grism.filter.lower(),
self.beam.beam)
hdu.writeto(outfile, overwrite=overwrite)
return outfile
[docs] def compute_model(self, use_psf=True, **kwargs):
"""Link to `self.beam.compute_model`
`self.beam` is a `GrismDisperser` object.
"""
if use_psf & hasattr(self.beam, 'psf'):
result = self.beam.compute_model_psf(**kwargs)
else:
result = self.beam.compute_model(**kwargs)
reset_inplace = True
if 'in_place' in kwargs:
reset_inplace = kwargs['in_place']
if reset_inplace:
self.modelf = self.beam.modelf # .flatten()
self.model = self.beam.modelf.reshape(self.beam.sh_beam)
return result
[docs] def get_wavelength_wcs(self, wavelength=1.3e4):
"""Compute *celestial* WCS of the 2D spectrum array for a specified central wavelength
This essentially recenters the celestial SIP WCS such that the
desired wavelength was at the object position as observed in the
direct image (which has associated geometric distortions etc).
Parameters
----------
wavelength : float
Central wavelength to use for derived WCS.
Returns
-------
header : `~astropy.io.fits.Header`
FITS header
wcs : `~astropy.wcs.WCS`
Derived celestial WCS
"""
wcs = self.grism.wcs.deepcopy()
xarr = np.arange(self.beam.lam_beam.shape[0])
# Trace properties at desired wavelength
dx = np.interp(wavelength, self.beam.lam_beam, xarr)
dy = np.interp(wavelength, self.beam.lam_beam, self.beam.ytrace_beam) + 1
dl = np.interp(wavelength, self.beam.lam_beam[1:],
np.diff(self.beam.lam_beam))
ysens = np.interp(wavelength, self.beam.lam_beam,
self.beam.sensitivity_beam)
# Update CRPIX
dc = 0 # python array center to WCS pixel center
# dc = 1.0 # 0.5
for wcs_ext in [wcs.sip, wcs.wcs]:
if wcs_ext is None:
continue
else:
cr = wcs_ext.crpix
cr[0] += dx + self.beam.x0[1] + self.beam.dxfull[0] + dc
cr[1] += dy + dc
for wcs_ext in [wcs.cpdis1, wcs.cpdis2, wcs.det2im1, wcs.det2im2]:
if wcs_ext is None:
continue
else:
cr = wcs_ext.crval
cr[0] += dx + self.beam.sh[0]/2 + self.beam.dxfull[0] + dc
cr[1] += dy + dc
# Make SIP CRPIX match CRPIX
# if wcs.sip is not None:
# for i in [0,1]:
# wcs.sip.crpix[i] = wcs.wcs.crpix[i]
for wcs_ext in [wcs.sip]:
if wcs_ext is not None:
for i in [0, 1]:
wcs_ext.crpix[i] = wcs.wcs.crpix[i]
# WCS header
header = utils.to_header(wcs, relax=True)
for key in header:
if key.startswith('PC'):
header.rename_keyword(key, key.replace('PC', 'CD'))
header['LONPOLE'] = 180.
header['RADESYS'] = 'ICRS'
header['LTV1'] = (0.0, 'offset in X to subsection start')
header['LTV2'] = (0.0, 'offset in Y to subsection start')
header['LTM1_1'] = (1.0, 'reciprocal of sampling rate in X')
header['LTM2_2'] = (1.0, 'reciprocal of sampling rate in X')
header['INVSENS'] = (ysens, 'inverse sensitivity, 10**-17 erg/s/cm2')
header['DLDP'] = (dl, 'delta wavelength per pixel')
return header, wcs
[docs] def get_2d_wcs(self, data=None, key=None):
"""Get simplified WCS of the 2D spectrum
Parameters
----------
data : array-like
Put this data in the output HDU rather than empty zeros
key : None
Key for WCS extension, passed to `~astropy.wcs.WCS.to_header`.
Returns
-------
hdu : `~astropy.io.fits.ImageHDU`
Image HDU with header and data properties.
wcs : `~astropy.wcs.WCS`
WCS appropriate for the 2D spectrum with spatial (y) and spectral
(x) axes.
.. note::
Assumes linear dispersion and trace functions!
"""
h = pyfits.Header()
h['WCSNAME'] = 'BeamLinear2D'
h['CRPIX1'] = self.beam.sh_beam[0]/2 - self.beam.xcenter
h['CRPIX2'] = self.beam.sh_beam[0]/2 - self.beam.ycenter
# Wavelength, A
h['CNAME1'] = 'Wave-Angstrom'
h['CTYPE1'] = 'WAVE'
#h['CUNIT1'] = 'Angstrom'
h['CRVAL1'] = self.beam.lam_beam[0]
h['CD1_1'] = self.beam.lam_beam[1] - self.beam.lam_beam[0]
h['CD1_2'] = 0.
# Linear trace
h['CNAME2'] = 'Trace'
h['CTYPE2'] = 'LINEAR'
h['CRVAL2'] = -1*self.beam.ytrace_beam[0]
h['CD2_2'] = 1.
h['CD2_1'] = -(self.beam.ytrace_beam[1] - self.beam.ytrace_beam[0])
if data is None:
data = np.zeros(self.beam.sh_beam, dtype=np.float32)
hdu = pyfits.ImageHDU(data=data, header=h)
wcs = pywcs.WCS(hdu.header)
#wcs.pscale = np.sqrt(wcs.wcs.cd[0,0]**2 + wcs.wcs.cd[1,0]**2)*3600.
wcs.pscale = utils.get_wcs_pscale(wcs)
return hdu, wcs
[docs] def full_2d_wcs(self, data=None):
"""Get trace WCS of the 2D spectrum
Parameters
----------
data : array-like
Put this data in the output HDU rather than empty zeros
Returns
-------
hdu : `~astropy.io.fits.ImageHDU`
Image HDU with header and data properties.
wcs : `~astropy.wcs.WCS`
WCS appropriate for the 2D spectrum with spatial (y) and spectral
(x) axes.
.. note::
Assumes linear dispersion and trace functions!
"""
h = pyfits.Header()
h['CRPIX1'] = self.beam.sh_beam[0]/2 - self.beam.xcenter
h['CRPIX2'] = self.beam.sh_beam[0]/2 - self.beam.ycenter
h['CRVAL1'] = self.beam.lam_beam[0]/1.e4
h['CD1_1'] = (self.beam.lam_beam[1] - self.beam.lam_beam[0])/1.e4
h['CD1_2'] = 0.
h['CRVAL2'] = -1*self.beam.ytrace_beam[0]
h['CD2_2'] = 1.
h['CD2_1'] = -(self.beam.ytrace_beam[1] - self.beam.ytrace_beam[0])
h['CTYPE1'] = 'RA---TAN-SIP'
h['CUNIT1'] = 'mas'
h['CTYPE2'] = 'DEC--TAN-SIP'
h['CUNIT2'] = 'mas'
#wcs_header = grizli.utils.to_header(self.grism.wcs)
x = np.arange(len(self.beam.lam_beam))
c = np.polynomial.Polynomial.fit(x,self.beam.lam_beam/1.e4,deg=2).convert().coef[::-1]
#c = np.polynomial.Polynomial.fit((self.beam.lam_beam-self.beam.lam_beam[0])/1.e4, x/h['CD1_1'],deg=2).convert().coef[::-1]
ct = np.polynomial.Polynomial.fit(x,self.beam.ytrace_beam,deg=2).convert().coef[::-1]
h['A_ORDER'] = 2
h['B_ORDER'] = 2
h['A_0_2'] = 0.
h['A_1_2'] = 0.
h['A_2_2'] = 0.
h['A_2_1'] = 0.
h['A_2_0'] = c[0] # /c[1]
h['CD1_1'] = c[1]
h['B_0_2'] = 0.
h['B_1_2'] = 0.
h['B_2_2'] = 0.
h['B_2_1'] = 0.
if ct[1] != 0:
h['B_2_0'] = ct[0] # /ct[1]
else:
h['B_2_0'] = 0
#h['B_2_0'] = 0
if data is None:
data = np.zeros(self.beam.sh_beam, dtype=np.float32)
hdu = pyfits.ImageHDU(data=data, header=h)
wcs = pywcs.WCS(hdu.header)
# xf = x + h['CRPIX1']-1
# coo = np.array([xf, xf*0])
# tr = wcs.all_pix2world(coo.T, 0)
#wcs.pscale = np.sqrt(wcs.wcs.cd[0,0]**2 + wcs.wcs.cd[1,0]**2)*3600.
wcs.pscale = utils.get_wcs_pscale(wcs)
return hdu, wcs
[docs] def get_sky_coords(self):
"""Get WCS coordinates of the center of the direct image
Returns
-------
ra, dec : float
Center coordinates of the beam thumbnail in decimal degrees
"""
pix_center = np.array([self.beam.sh][::-1])/2.
pix_center -= np.array([self.beam.xcenter, self.beam.ycenter])
if self.direct.wcs.sip is not None:
for i in range(2):
self.direct.wcs.sip.crpix[i] = self.direct.wcs.wcs.crpix[i]
ra, dec = self.direct.wcs.all_pix2world(pix_center, 1)[0]
return ra, dec
[docs] def get_dispersion_PA(self, decimals=0, local=False):
"""Compute exact PA of the dispersion axis, including tilt of the
trace and the FLT WCS
Parameters
----------
decimals : int or None
Number of decimal places to round to, passed to `~numpy.round`.
If None, then don't round.
local : bool
Compute local PA of a given beam, otherwise use global WCS
Returns
-------
dispersion_PA : float
PA (angle East of North) of the increasing-wavelength dispersion axis.
"""
from astropy.coordinates import Angle
import astropy.units as u
# extra tilt of the 1st order grism spectra
if 'BEAMA' in self.beam.conf.conf_dict:
x0 = self.beam.conf.conf_dict['BEAMA']
else:
x0 = np.array([10,30])
if local:
xp = self.beam.xc - self.beam.pad[1]
yp = self.beam.yc - self.beam.pad[0]
else:
xp = yp = 507 # Dummy, WFC3/IR center
x0 = np.mean(x0) + np.array([-10,10])
dy_trace, lam_trace = self.beam.conf.get_beam_trace(x=xp, y=yp, dx=x0, beam='A')
# Distorted WCS
crpix = self.direct.wcs.wcs.crpix
if local:
xref = -crpix[0] + x0
yref = [-crpix[1]+dy_trace[0], -crpix[1]+dy_trace[1]]
else:
xref = crpix[0] + x0
yref = [crpix[1]+dy_trace[0], crpix[1]+dy_trace[1]]
r, d = self.direct.wcs.all_pix2world(xref, yref, 1)
dra = np.diff(r)*np.cos(d[0]/180*np.pi)
dde = np.diff(d)[0]
pa = Angle((np.arctan2(dra, dde)/np.pi*180)*u.deg)
dispersion_PA = pa.wrap_at(360*u.deg).value
if decimals is not None:
dispersion_PA = np.round(dispersion_PA, decimals=decimals)
return dispersion_PA[0]
[docs] def init_epsf(self, center=None, tol=1.e-3, yoff=0., skip=1., flat_sensitivity=False, psf_params=None, N=4, get_extended=False, only_centering=True):
"""Initialize ePSF fitting for point sources
TBD
"""
import scipy.sparse
EPSF = utils.EffectivePSF()
ivar = 1/self.direct['ERR']**2
ivar[~np.isfinite(ivar)] = 0
ivar[self.direct['DQ'] > 0] = 0
ivar[self.beam.seg != self.id] = 0
if ivar.max() == 0:
ivar = ivar+1.
origin = np.array(self.direct.origin) - np.array(self.direct.pad)
if psf_params is None:
self.beam.psf_ivar = ivar*1
self.beam.psf_sci = self.direct['SCI']*1
self.psf_params = EPSF.fit_ePSF(self.direct['SCI'],
ivar=ivar,
center=center, tol=tol,
N=N, origin=origin,
filter=self.direct.filter,
get_extended=get_extended,
only_centering=only_centering)
else:
self.beam.psf_ivar = ivar*1
self.beam.psf_sci = self.direct['SCI']*1
self.psf_params = psf_params
self.beam.x_init_epsf(flat_sensitivity=False, psf_params=self.psf_params, psf_filter=self.direct.filter, yoff=yoff, skip=skip, get_extended=get_extended)
self._parse_from_data(**self._parse_params)
return None
# self.psf = EPSF.get_ePSF(self.psf_params, origin=origin, shape=self.beam.sh, filter=self.direct.filter)
#
# self.psf_resid = self.direct['SCI'] - self.psf
#
# y0, x0 = np.array(self.beam.sh)/2.-1
#
# # Center in detector coords
# xd = self.psf_params[1] + self.direct.origin[1] - self.direct.pad + x0
# yd = self.psf_params[2] + self.direct.origin[0] - self.direct.pad + y0
#
# # Get wavelength array
# psf_xy_lam = []
# for i, filter in enumerate(['F105W', 'F125W', 'F160W']):
# psf_xy_lam.append(EPSF.get_at_position(x=xd, y=yd, filter=filter))
#
# filt_ix = np.arange(3)
# filt_lam = np.array([1.0551, 1.2486, 1.5369])*1.e4
#
# yp_beam, xp_beam = np.indices(self.beam.sh_beam)
# #skip = 1
# xarr = np.arange(0,self.beam.lam_beam.shape[0], skip)
# xarr = xarr[xarr <= self.beam.lam_beam.shape[0]-1]
# xbeam = np.arange(self.beam.lam_beam.shape[0])*1.
#
# #yoff = 0 #-0.15
# psf_model = self.model*0.
# A_psf = []
# lam_psf = []
#
# lam_offset = self.beam.sh[1]/2 - self.psf_params[1] - 1
# self.lam_offset = lam_offset
#
# for xi in xarr:
# yi = np.interp(xi, xbeam, self.beam.ytrace_beam)
# li = np.interp(xi, xbeam, self.beam.lam_beam)
# dx = xp_beam-self.psf_params[1]-xi-x0
# dy = yp_beam-self.psf_params[2]-yi+yoff-y0
#
# # wavelength-dependent
# ii = np.interp(li, filt_lam, filt_ix, left=-1, right=10)
# if ii == -1:
# psf_xy_i = psf_xy_lam[0]*1
# elif ii == 10:
# psf_xy_i = psf_xy_lam[2]*1
# else:
# ni = int(ii)
# f = 1-(li-filt_lam[ni])/(filt_lam[ni+1]-filt_lam[ni])
# psf_xy_i = f*psf_xy_lam[ni] + (1-f)*psf_xy_lam[ni+1]
#
# psf = EPSF.eval_ePSF(psf_xy_i, dx, dy)*self.psf_params[0]
#
# A_psf.append(psf.flatten())
# lam_psf.append(li)
#
# # Sensitivity
# self.lam_psf = np.array(lam_psf)
# if flat_sensitivity:
# s_i_scale = np.abs(np.gradient(self.lam_psf))*self.direct.photflam
# else:
# sens = self.beam.conf.sens[self.beam.beam]
# so = np.argsort(self.lam_psf)
# s_i = interp.interp_conserve_c(self.lam_psf[so], sens['WAVELENGTH'], sens['SENSITIVITY'])*np.gradient(self.lam_psf[so])*self.direct.photflam
# s_i_scale = s_i*0.
# s_i_scale[so] = s_i
#
# self.A_psf = scipy.sparse.csr_matrix(np.array(A_psf).T*s_i_scale)
# def xcompute_model_psf(self, id=None, spectrum_1d=None, in_place=True, is_cgs=True):
# if spectrum_1d is None:
# model = np.array(self.A_psf.sum(axis=1))
# model = model.reshape(self.beam.sh_beam)
# else:
# dx = np.diff(self.lam_psf)[0]
# if dx < 0:
# coeffs = interp.interp_conserve_c(self.lam_psf[::-1],
# spectrum_1d[0],
# spectrum_1d[1])[::-1]
# else:
# coeffs = interp.interp_conserve_c(self.lam_psf,
# spectrum_1d[0],
# spectrum_1d[1])
#
#
# model = self.A_psf.dot(coeffs).reshape(self.beam.sh_beam)
#
# if in_place:
# self.model = model
# self.beam.model = self.model
# return True
# else:
# return model.flatten()
# Below here will be cut out after verifying that the demos
# can be run with the new fitting tools
[docs] def init_poly_coeffs(self, poly_order=1, fit_background=True):
"""Initialize arrays for polynomial fits to the spectrum
Provides capabilities of fitting n-order polynomials to observed
spectra rather than galaxy/stellar templates.
Parameters
----------
poly_order : int
Order of the polynomial
fit_background : bool
Compute additional arrays for allowing the background to be fit
along with the polynomial coefficients.
Returns
-------
Polynomial parameters stored in attributes `y_poly`, `n_poly`, ...
"""
# Already done?
if poly_order == self.poly_order:
return None
self.poly_order = poly_order
# Model: (a_0 x**0 + ... a_i x**i)*continuum + line
yp, xp = np.indices(self.beam.sh_beam)
NX = self.beam.sh_beam[1]
self.xpf = (xp.flatten() - NX/2.)
self.xpf /= (NX/2.)
# Polynomial continuum arrays
if fit_background:
self.n_bg = 1
self.A_poly = [self.flat_flam*0+1]
self.A_poly.extend([self.xpf**order*self.flat_flam
for order in range(poly_order+1)])
else:
self.n_bg = 0
self.A_poly = [self.xpf**order*self.flat_flam
for order in range(poly_order+1)]
# Array for generating polynomial "template"
x = (np.arange(NX) - NX/2.) / (NX/2.)
self.y_poly = np.array([x**order for order in range(poly_order+1)])
self.n_poly = self.y_poly.shape[0]
self.n_simp = self.n_poly + self.n_bg
self.DoF = self.fit_mask.sum()
[docs] def show_redshift_fit(self, fit_data):
"""Make a plot based on results from `simple_line_fit`.
Parameters
----------
fit_data : dict
returned data from `simple_line_fit`. I.e.,
>>> fit_outputs = BeamCutout.simple_line_fit()
>>> fig = BeamCutout.show_simple_fit_results(fit_outputs)
Returns
-------
fig : `~matplotlib.figure.Figure`
Figure object that can be optionally written to a hardcopy file.
"""
import matplotlib.gridspec
#zgrid, A, coeffs, chi2, model_best, model_continuum, model1d = fit_outputs
# Full figure
fig = plt.figure(figsize=(12, 5))
#fig = plt.Figure(figsize=(8,4))
# 1D plots
gsb = matplotlib.gridspec.GridSpec(3, 1)
xspec, yspec, yerr = self.beam.optimal_extract(self.grism.data['SCI']
- self.contam,
ivar=self.ivar)
flat_model = self.flat_flam.reshape(self.beam.sh_beam)
xspecm, yspecm, yerrm = self.beam.optimal_extract(flat_model)
out = self.beam.optimal_extract(fit_data['model_full'])
xspecl, yspecl, yerrl = out
ax = fig.add_subplot(gsb[-2:, :])
ax.errorbar(xspec/1.e4, yspec, yerr, linestyle='None', marker='o',
markersize=3, color='black', alpha=0.5,
label='Data (id={0:d})'.format(self.beam.id))
ax.plot(xspecm/1.e4, yspecm, color='red', linewidth=2, alpha=0.8,
label=r'Flat $f_\lambda$ ({0})'.format(self.direct.filter))
zbest = fit_data['zgrid'][np.argmin(fit_data['chi2'])]
ax.plot(xspecl/1.e4, yspecl, color='orange', linewidth=2, alpha=0.8,
label='Template (z={0:.4f})'.format(zbest))
ax.legend(fontsize=8, loc='lower center', scatterpoints=1)
ax.set_xlabel(r'$\lambda$')
ax.set_ylabel('flux (e-/s)')
if self.grism.filter == 'G102':
xlim = [0.7, 1.25]
if self.grism.filter == 'G141':
xlim = [1., 1.8]
xt = np.arange(xlim[0], xlim[1], 0.1)
ax.set_xlim(xlim[0], xlim[1])
ax.set_xticks(xt)
ax = fig.add_subplot(gsb[-3, :])
ax.plot(fit_data['zgrid'], fit_data['chi2']/self.DoF)
for d in [1, 4, 9]:
ax.plot(fit_data['zgrid'],
fit_data['chi2']*0+(fit_data['chi2'].min()+d)/self.DoF,
color='{0:.1f}'.format(d/20.))
# ax.set_xticklabels([])
ax.set_ylabel(r'$\chi^2/(\nu={0:d})$'.format(self.DoF))
ax.set_xlabel('z')
ax.set_xlim(fit_data['zgrid'][0], fit_data['zgrid'][-1])
# axt = ax.twiny()
# axt.set_xlim(np.array(ax.get_xlim())*1.e4/6563.-1)
# axt.set_xlabel(r'$z_\mathrm{H\alpha}$')
# 2D spectra
gst = matplotlib.gridspec.GridSpec(4, 1)
if 'viridis_r' in plt.colormaps():
cmap = 'viridis_r'
else:
cmap = 'cubehelix_r'
ax = fig.add_subplot(gst[0, :])
ax.imshow(self.grism.data['SCI'], vmin=-0.05, vmax=0.2, cmap=cmap,
interpolation='Nearest', origin='lower', aspect='auto')
ax.set_ylabel('Observed')
ax = fig.add_subplot(gst[1, :])
mask2d = self.fit_mask.reshape(self.beam.sh_beam)
ax.imshow((self.grism.data['SCI'] - self.contam)*mask2d,
vmin=-0.05, vmax=0.2, cmap=cmap,
interpolation='Nearest', origin='lower', aspect='auto')
ax.set_ylabel('Masked')
ax = fig.add_subplot(gst[2, :])
ax.imshow(fit_data['model_full']+self.contam, vmin=-0.05, vmax=0.2,
cmap=cmap, interpolation='Nearest', origin='lower',
aspect='auto')
ax.set_ylabel('Model')
ax = fig.add_subplot(gst[3, :])
ax.imshow(self.grism.data['SCI']-fit_data['model_full']-self.contam,
vmin=-0.05, vmax=0.2, cmap=cmap, interpolation='Nearest',
origin='lower', aspect='auto')
ax.set_ylabel('Resid.')
for ax in fig.axes[-4:]:
self.beam.twod_axis_labels(wscale=1.e4,
limits=[xlim[0], xlim[1], 0.1],
mpl_axis=ax)
self.beam.twod_xlim(xlim, wscale=1.e4, mpl_axis=ax)
ax.set_yticklabels([])
ax.set_xlabel(r'$\lambda$')
for ax in fig.axes[-4:-1]:
ax.set_xticklabels([])
gsb.tight_layout(fig, pad=0.1, h_pad=0.01, rect=(0, 0, 0.5, 1))
gst.tight_layout(fig, pad=0.1, h_pad=0.01, rect=(0.5, 0.01, 1, 0.98))
return fig
[docs] def simple_line_fit(self, fwhm=48., grid=[1.12e4, 1.65e4, 1, 4],
fitter='lstsq', poly_order=3):
"""Function to fit a Gaussian emission line and a polynomial continuum
Parameters
----------
fwhm : float
FWHM of the emission line
grid : list `[l0, l1, dl, skip]`
The base wavelength array will be generated like
>>> wave = np.arange(l0, l1, dl)
and lines will be generated every `skip` wavelength grid points:
>>> line_centers = wave[::skip]
fitter : str, 'lstsq' or 'sklearn'
Least-squares fitting function for determining template
normalization coefficients.
order : int (>= 0)
Polynomial order to use for the continuum
Returns
-------
line_centers : length N `~numpy.array`
emission line center positions
coeffs : (N, M) `~numpy.ndarray` where `M = (poly_order+1+1)`
Normalization coefficients for the continuum and emission line
templates.
chi2 : `~numpy.array`
Chi-squared evaluated at each line_centers[i]
ok_data : `~numpy.ndarray`
Boolean mask of pixels used for the Chi-squared calculation.
Consists of non-masked DQ pixels, non-zero ERR pixels and pixels
where `self.model > 0.03*self.model.max()` for the flat-spectrum
model.
best_model : `~numpy.ndarray`
2D array with best-fit continuum + line model
best_model_cont : `~numpy.ndarray`
2D array with Best-fit continuum-only model.
best_line_center : float
wavelength where chi2 is minimized.
best_line_flux : float
Emission line flux where chi2 is minimized
"""
# Test fit
import sklearn.linear_model
import numpy.linalg
clf = sklearn.linear_model.LinearRegression()
# Continuum
self.compute_model()
self.model = self.modelf.reshape(self.beam.sh_beam)
# OK data where the 2D model has non-zero flux
ok_data = (~self.mask.flatten()) & (self.ivar.flatten() != 0)
ok_data &= (self.modelf > 0.03*self.modelf.max())
# Flat versions of sci/ivar arrays
scif = (self.grism.data['SCI'] - self.contam).flatten()
ivarf = self.ivar.flatten()
# Model: (a_0 x**0 + ... a_i x**i)*continuum + line
yp, xp = np.indices(self.beam.sh_beam)
xpf = (xp.flatten() - self.beam.sh_beam[1]/2.)
xpf /= (self.beam.sh_beam[1]/2)
# Polynomial continuum arrays
A_list = [xpf**order*self.modelf for order in range(poly_order+1)]
# Extra element for the computed line model
A_list.append(self.modelf*1)
A = np.vstack(A_list).T
# Normalized Gaussians on a grid
waves = np.arange(grid[0], grid[1], grid[2])
line_centers = waves[grid[3] // 2::grid[3]]
rms = fwhm/2.35
gaussian_lines = np.exp(-(line_centers[:, None]-waves)**2/2/rms**2)
gaussian_lines /= np.sqrt(2*np.pi*rms**2)
N = len(line_centers)
coeffs = np.zeros((N, A.shape[1]))
chi2 = np.zeros(N)
chi2min = 1e30
# Loop through line models and fit for template coefficients
# Compute chi-squared.
for i in range(N):
self.compute_model(spectrum_1d=[waves, gaussian_lines[i, :]])
A[:, -1] = self.model.flatten()
if fitter == 'lstsq':
out = np.linalg.lstsq(A[ok_data, :], scif[ok_data],
rcond=utils.LSTSQ_RCOND)
lstsq_coeff, residuals, rank, s = out
coeffs[i, :] += lstsq_coeff
model = np.dot(A, lstsq_coeff)
else:
status = clf.fit(A[ok_data, :], scif[ok_data])
coeffs[i, :] = clf.coef_
model = np.dot(A, clf.coef_)
chi2[i] = np.sum(((scif-model)**2*ivarf)[ok_data])
if chi2[i] < chi2min:
chi2min = chi2[i]
# print chi2
ix = np.argmin(chi2)
self.compute_model(spectrum_1d=[waves, gaussian_lines[ix, :]])
A[:, -1] = self.model.flatten()
best_coeffs = coeffs[ix, :]*1
best_model = np.dot(A, best_coeffs).reshape(self.beam.sh_beam)
# Continuum
best_coeffs_cont = best_coeffs*1
best_coeffs_cont[-1] = 0.
best_model_cont = np.dot(A, best_coeffs_cont)
best_model_cont = best_model_cont.reshape(self.beam.sh_beam)
best_line_center = line_centers[ix]
best_line_flux = coeffs[ix, -1]*self.beam.total_flux/1.e-17
return (line_centers, coeffs, chi2, ok_data,
best_model, best_model_cont,
best_line_center, best_line_flux)
[docs] def show_simple_fit_results(self, fit_outputs):
"""Make a plot based on results from `simple_line_fit`.
Parameters
----------
fit_outputs : tuple
returned data from `simple_line_fit`. I.e.,
>>> fit_outputs = BeamCutout.simple_line_fit()
>>> fig = BeamCutout.show_simple_fit_results(fit_outputs)
Returns
-------
fig : `~matplotlib.figure.Figure`
Figure object that can be optionally written to a hardcopy file.
"""
import matplotlib.gridspec
line_centers, coeffs, chi2, ok_data, best_model, best_model_cont, best_line_center, best_line_flux = fit_outputs
# Full figure
fig = plt.figure(figsize=(10, 5))
#fig = plt.Figure(figsize=(8,4))
# 1D plots
gsb = matplotlib.gridspec.GridSpec(3, 1)
xspec, yspec, yerr = self.beam.optimal_extract(self.grism.data['SCI']
- self.contam,
ivar=self.ivar)
flat_model = self.compute_model(in_place=False)
flat_model = flat_model.reshape(self.beam.sh_beam)
xspecm, yspecm, yerrm = self.beam.optimal_extract(flat_model)
xspecl, yspecl, yerrl = self.beam.optimal_extract(best_model)
ax = fig.add_subplot(gsb[-2:, :])
ax.errorbar(xspec/1.e4, yspec, yerr, linestyle='None', marker='o',
markersize=3, color='black', alpha=0.5,
label='Data (id={0:d})'.format(self.beam.id))
ax.plot(xspecm/1.e4, yspecm, color='red', linewidth=2, alpha=0.8,
label=r'Flat $f_\lambda$ ({0})'.format(self.direct.filter))
ax.plot(xspecl/1.e4, yspecl, color='orange', linewidth=2, alpha=0.8,
label='Cont+line ({0:.4f}, {1:.2e})'.format(best_line_center/1.e4, best_line_flux*1.e-17))
ax.legend(fontsize=8, loc='lower center', scatterpoints=1)
ax.set_xlabel(r'$\lambda$')
ax.set_ylabel('flux (e-/s)')
ax = fig.add_subplot(gsb[-3, :])
ax.plot(line_centers/1.e4, chi2/ok_data.sum())
ax.set_xticklabels([])
ax.set_ylabel(r'$\chi^2/(\nu={0:d})$'.format(ok_data.sum()))
if self.grism.filter == 'G102':
xlim = [0.7, 1.25]
if self.grism.filter == 'G141':
xlim = [1., 1.8]
xt = np.arange(xlim[0], xlim[1], 0.1)
for ax in fig.axes:
ax.set_xlim(xlim[0], xlim[1])
ax.set_xticks(xt)
axt = ax.twiny()
axt.set_xlim(np.array(ax.get_xlim())*1.e4/6563.-1)
axt.set_xlabel(r'$z_\mathrm{H\alpha}$')
# 2D spectra
gst = matplotlib.gridspec.GridSpec(3, 1)
if 'viridis_r' in plt.colormaps():
cmap = 'viridis_r'
else:
cmap = 'cubehelix_r'
ax = fig.add_subplot(gst[0, :])
ax.imshow(self.grism.data['SCI'], vmin=-0.05, vmax=0.2, cmap=cmap,
interpolation='Nearest', origin='lower', aspect='auto')
ax.set_ylabel('Observed')
ax = fig.add_subplot(gst[1, :])
ax.imshow(best_model+self.contam, vmin=-0.05, vmax=0.2, cmap=cmap,
interpolation='Nearest', origin='lower', aspect='auto')
ax.set_ylabel('Model')
ax = fig.add_subplot(gst[2, :])
ax.imshow(self.grism.data['SCI']-best_model-self.contam, vmin=-0.05,
vmax=0.2, cmap=cmap, interpolation='Nearest',
origin='lower', aspect='auto')
ax.set_ylabel('Resid.')
for ax in fig.axes[-3:]:
self.beam.twod_axis_labels(wscale=1.e4,
limits=[xlim[0], xlim[1], 0.1],
mpl_axis=ax)
self.beam.twod_xlim(xlim, wscale=1.e4, mpl_axis=ax)
ax.set_yticklabels([])
ax.set_xlabel(r'$\lambda$')
for ax in fig.axes[-3:-1]:
ax.set_xticklabels([])
gsb.tight_layout(fig, pad=0.1, h_pad=0.01, rect=(0, 0, 0.5, 1))
gst.tight_layout(fig, pad=0.1, h_pad=0.01, rect=(0.5, 0.1, 1, 0.9))
return fig