"""
Functionality for manipulating multiple grism exposures simultaneously
"""
import os
import time
import traceback
import glob
from collections import OrderedDict
import multiprocessing as mp
from . import prep
import scipy.ndimage as nd
import numpy as np
import matplotlib.pyplot as plt
from astropy.table import Table
import astropy.io.fits as pyfits
import astropy.wcs as pywcs
import astropy.units as u
# local imports
from . import utils
from . import model
from . import grismconf
# from . import stack
from .fitting import GroupFitter
from .utils_numba import interp
from .utils import GRISM_COLORS, GRISM_MAJOR, GRISM_LIMITS, DEFAULT_LINE_LIST
def _loadFLT(
grism_file,
sci_extn,
direct_file,
pad,
ref_file,
ref_ext,
seg_file,
verbose,
catalog,
ix,
use_jwst_crds,
):
"""
Helper function for loading `.model.GrismFLT` objects.
Parameters
----------
grism_file : str
Filename of the grism exposure.
sci_extn : int
Science extension to extract from the files in `grism_files`. For
WFC3/IR this can only be 1, though for the two-chip instruments
WFC3/UVIS and ACS/WFC3 this can be 1 or 2.
direct_file : str
Filename of the direct image associated with the grism exposure.
pad : tuple
Padding in pixels to apply around the edge of the detector to
allow modeling of sources that fall off of the nominal FOV. For
this to work requires using a `ref_file` reference image that
covers this extra area. Specified in array axis order (pady, padx).
ref_file : `None` or str
Undistorted reference image filename, e.g., a drizzled mosaic
covering the area around a given grism exposure.
ref_ext : int
FITS extension of the reference file where to find the image
itself.
seg_file : `None` or str
Segmentation image filename.
verbose : bool
Print verbose information to the terminal.
catalog : `~astropy.table.Table`
The table read in with from the above file specified in `catalog`.
ix : int
Index of the grism exposure (not used).
use_jwst_crds : bool
Use the JWST CRDS system to get the WCS information.
Returns
-------
flt : `~grizli.model.GrismFLT`
The grism exposure object.
"""
import time
try:
import cPickle as pickle
except:
# Python 3
import pickle
# slight random delay to avoid synchronization problems
# np.random.seed(ix)
# sleeptime = ix*1
# print '%s sleep %.3f %d' %(grism_file, sleeptime, ix)
# time.sleep(sleeptime)
# print grism_file, direct_file
new_root = ".{0:02d}.GrismFLT.fits".format(sci_extn)
save_file = grism_file.replace("_flt.fits", new_root)
save_file = save_file.replace("_flc.fits", new_root)
save_file = save_file.replace("_cmb.fits", new_root)
save_file = save_file.replace("_rate.fits", new_root)
save_file = save_file.replace("_elec.fits", new_root)
if (save_file == grism_file) & ("GrismFLT" not in grism_file):
# couldn't build new filename based on the extensions
# so just insert at the end
save_file = grism_file.replace(".fits", new_root)
if (grism_file.find("_") < 0) & ("GrismFLT" not in grism_file):
save_file = "xxxxxxxxxxxxxxxxxxx"
if os.path.exists(save_file) & ("GrismFLT" in save_file):
print("Load {0}!".format(save_file))
fp = open(save_file.replace("GrismFLT.fits", "GrismFLT.pkl"), "rb")
flt = pickle.load(fp)
fp.close()
status = flt.load_from_fits(save_file)
else:
flt = model.GrismFLT(
grism_file=grism_file,
sci_extn=sci_extn,
direct_file=direct_file,
pad=pad,
ref_file=ref_file,
ref_ext=ref_ext,
seg_file=seg_file,
shrink_segimage=True,
verbose=verbose,
use_jwst_crds=use_jwst_crds,
)
if flt.direct.wcs.wcs.has_pc():
for obj in [flt.grism, flt.direct]:
obj.get_wcs()
if catalog is not None:
flt.catalog = flt.blot_catalog(
catalog, sextractor=("X_WORLD" in catalog.colnames)
)
flt.catalog_file = catalog
else:
flt.catalog = None
if flt.grism.instrument in ["NIRCAM"]:
flt.apply_POM(verbose=verbose)
if flt.grism.instrument in ["NIRISS", "NIRCAM"]:
flt.transform_JWST_WFSS(verbose=verbose)
if hasattr(flt, "conf"):
delattr(flt, "conf")
return flt # , out_cat
def _fit_at_z(self, zgrid, i, templates, fitter, fit_background, poly_order):
"""
For parallel processing of `~grizli.multifit.GroupFLT.fit_at_z`.
Parameters
----------
zgrid : list
Redshift grid.
i : int
Index of the redshift grid to fit.
templates : dict
Template dictionary.
fitter : str
Fitting method.
fit_background : bool
Fit a polynomial background.
poly_order : int
Order of the polynomial background.
Returns
-------
data : dict
Dictionary of the fitting results.
"""
# self, z=0., templates={}, fitter='nnls',
# fit_background=True, poly_order=0
print(i, zgrid[i])
out = self.fit_at_z(
z=zgrid[i],
templates=templates,
fitter=fitter,
poly_order=poly_order,
fit_background=fit_background,
)
data = {"out": out, "i": i}
return data
# A, coeffs[i,:], chi2[i], model_2d = out
def _beam_compute_model(beam, id, spectrum_1d, is_cgs, apply_sensitivity, scale, reset):
"""
Wrapper function for multiprocessing of `~grizli.model.Beam.compute_model`.
Parameters
----------
beam : `~grizli.model.Beam`
The beam object.
id : int
Object ID.
spectrum_1d : list
1D spectrum to fit.
is_cgs : bool
Spectrum is in cgs units.
apply_sensitivity : bool
Apply the sensitivity function.
scale : float
Scale factor for the spectrum.
reset : bool
Reset the model.
Returns
-------
status : bool
True if the model was computed successfully.
"""
beam.beam.compute_model(
id=id,
spectrum_1d=spectrum_1d,
is_cgs=is_cgs,
scale=scale,
reset=reset,
apply_sensitivity=apply_sensitivity,
)
beam.modelf = beam.beam.modelf
beam.model = beam.beam.modelf.reshape(beam.beam.sh_beam)
return True
# def test_parallel():
#
# zgrid = np.linspace(1.1, 1.3, 10)
# templates = mb.load_templates(fwhm=800)
# fitter = 'nnls'
# fit_background = True
# poly_order = 0
#
# self.FLTs = []
# t0_pool = time.time()
#
# pool = mp.Pool(processes=4)
# results = [pool.apply_async(_fit_at_z, (mb, zgrid, i, templates, fitter, fit_background, poly_order)) for i in range(len(zgrid))]
#
# pool.close()
# pool.join()
#
# chi = zgrid*0.
#
# for res in results:
# data = res.get(timeout=1)
# A, coeffs, chi[data['i']], model_2d = data['out']
# #flt_i.catalog = cat_i
#
# t1_pool = time.time()
def _compute_model(i, flt, fit_info, is_cgs, store, model_kwargs):
"""
Helper function for computing model orders in parallel with `multiprocessing`.
Parameters
----------
i : int
Index of the `~grizli.model.GrismFLT` object in `self.FLTs`.
flt : `~grizli.model.GrismFLT`
The grism exposure object.
fit_info : dict
Dictionary of the fitting information.
is_cgs : bool
Spectrum is in cgs units.
store : bool
Store the model in the `flt` object.
model_kwargs : dict
Keywords to pass to the `~grizli.model.GrismFLT.compute_model_orders`
method of the `~grizli.model.GrismFLT` objects.
Returns
-------
i : int
Index of the `~grizli.model.GrismFLT` object in `self.FLTs`.
flt.model : `~numpy.ndarray`
The model spectrum.
flt.object_dispersers : dict
Dictionary of the object dispersers.
"""
if not hasattr(flt, "conf"):
flt.conf = grismconf.load_grism_config(flt.conf_file)
for id in fit_info:
try:
status = flt.compute_model_orders(
id=id,
mag=fit_info[id]["mag"],
in_place=True,
store=store,
spectrum_1d=fit_info[id]["spec"],
is_cgs=is_cgs,
verbose=False,
**model_kwargs,
)
except:
print("Failed: {0} {1}".format(flt.grism.parent_file, id))
continue
print("{0}: _compute_model Done".format(flt.grism.parent_file))
return i, flt.model, flt.object_dispersers
[docs]class GroupFLT:
def __init__(
self,
grism_files=[],
sci_extn=1,
direct_files=[],
pad=(64, 256),
group_name="group",
ref_file=None,
ref_ext=0,
seg_file=None,
shrink_segimage=True,
verbose=True,
cpu_count=0,
catalog="",
polyx=[0.3, 5.3],
MW_EBV=0.0,
bits=None,
use_jwst_crds=False,
):
"""
Main container for handling multiple grism exposures together
Parameters
----------
grism_files : list
List of grism exposures (typically WFC3/IR "FLT" or ACS/UVIS "FLC"
files). These can be from different grisms and/or orients.
sci_extn : int
Science extension to extract from the files in `grism_files`. For
WFC3/IR this can only be 1, though for the two-chip instruments
WFC3/UVIS and ACS/WFC3 this can be 1 or 2.
direct_files : list
List of direct exposures (typically WFC3/IR "FLT" or ACS/UVIS
"FLC" files). This list should either be empty or should
correspond one-to-one with entries in the `grism_files` list,
i.e., from an undithered pair of direct and grism exposures. If
such pairs weren't obtained or if you simply wish to ignore them
and just use the `ref_file` reference image, set to an empty list
(`[]`).
pad : int, int
Padding in pixels to apply around the edge of the detector to
allow modeling of sources that fall off of the nominal FOV. For
this to work requires using a `ref_file` reference image that
covers this extra area. Specified in array axis order (pady, padx)
group_name : str
Name to apply to products produced by this group.
ref_file : `None` or str
Undistorted reference image filename, e.g., a drizzled mosaic
covering the area around a given grism exposure.
ref_ext : 0
FITS extension of the reference file where to find the image
itself.
seg_file : `None` or str
Segmentation image filename.
shrink_segimage : bool
Do some preprocessing on the segmentation image to speed up the
blotting to the distorted frame of the grism exposures.
Not used.
verbose : bool
Print verbose information.
cpu_count : int
Use parallelization if > 0. If equal to zero, then use the
maximum number of available cores.
catalog : str
Catalog filename assocated with `seg_file`. These are typically
generated with "SExtractor", but the source of the files
themselves isn't critical.
polyx : list
Wavelength range for polynomial fits.
MW_EBV : float
Galactic extinction to apply to the models.
Not used.
bits : int
Bitmask to apply to the segmentation image.
Not used.
use_jwst_crds : bool
Use the JWST CRDS system to get the WCS information.
Attributes
----------
catalog : `~astropy.table.Table`
The table read in with from the above file specified in `catalog`.
FLTs : list
List of `~grizli.model.GrismFLT` objects generated from each of
the files in the `grism_files` list.
grp.N : int
Number of grism files (i.e., `len(FLTs)`.)
"""
N = len(grism_files)
if len(direct_files) != len(grism_files):
direct_files = [""] * N
self.grism_files = grism_files
self.direct_files = direct_files
self.group_name = group_name
# Wavelengths for polynomial fits
self.polyx = polyx
# Read catalog
if catalog:
if isinstance(catalog, str):
self.catalog = utils.GTable.gread(catalog)
else:
self.catalog = catalog
# necessary columns from SExtractor / photutils
pairs = [["NUMBER", "id"], ["MAG_AUTO", "mag"], ["MAGERR_AUTO", "mag_err"]]
cols = self.catalog.colnames
for pair in pairs:
if (pair[0] not in cols) & (pair[1] in cols):
self.catalog[pair[0]] = self.catalog[pair[1]]
else:
self.catalog = None
if cpu_count == 0:
cpu_count = mp.cpu_count()
self.FLTs = []
if cpu_count < 0:
# serial
t0_pool = time.time()
for i in range(N):
flt = _loadFLT(
self.grism_files[i],
sci_extn,
self.direct_files[i],
pad,
ref_file,
ref_ext,
seg_file,
verbose,
self.catalog,
i,
use_jwst_crds,
)
self.FLTs.append(flt)
t1_pool = time.time()
else:
# Read files in parallel
t0_pool = time.time()
pool = mp.Pool(processes=cpu_count)
results = [
pool.apply_async(
_loadFLT,
(
self.grism_files[i],
sci_extn,
self.direct_files[i],
pad,
ref_file,
ref_ext,
seg_file,
verbose,
self.catalog,
i,
use_jwst_crds,
),
)
for i in range(N)
]
pool.close()
pool.join()
for res in results:
flt_i = res.get(timeout=1)
# flt_i.catalog = cat_i
# somehow WCS getting flipped from cd to pc in res.get()???
if flt_i.direct.wcs.wcs.has_pc():
for obj in [flt_i.grism, flt_i.direct]:
obj.get_wcs()
self.FLTs.append(flt_i)
t1_pool = time.time()
# Set conf
for flt_i in self.FLTs:
if not hasattr(flt_i, "conf"):
flt_i.conf = grismconf.load_grism_config(flt_i.conf_file)
if verbose:
print("Files loaded - {0:.2f} sec.".format(t1_pool - t0_pool))
@property
def N(self):
"""
Number of grism files.
"""
return len(self.FLTs)
@property
def Ngrism(self):
"""
Dictionary containing number of exposures by grism.
"""
# Parse grisms & PAs
Ngrism = {}
for flt in self.FLTs:
if flt.grism.instrument == "NIRISS":
grism = flt.grism.pupil
else:
grism = flt.grism.filter
if grism not in Ngrism:
Ngrism[grism] = 0
Ngrism[grism] += 1
return Ngrism
@property
def grisms(self):
"""
Available grisms.
"""
grisms = list(self.Ngrism.keys())
return grisms
@property
def PA(self):
"""
Available PAs in each grism.
"""
_PA = {}
for g in self.Ngrism:
_PA[g] = {}
for i, flt in enumerate(self.FLTs):
if flt.grism.instrument == "NIRISS":
grism = flt.grism.pupil
else:
grism = flt.grism.filter
PA_i = flt.get_dispersion_PA(decimals=0)
if PA_i not in _PA[grism]:
_PA[grism][PA_i] = []
_PA[grism][PA_i].append(i)
return _PA
[docs] def save_full_data(self, warn=True, verbose=False):
"""
Save models and data files for fast regeneration.
The filenames of the outputs are generated from the input grism
exposure filenames with the following:
>>> file = 'ib3701ryq_flt.fits'
>>> sci_extn = 1
>>> new_root = '.{0:02d}.GrismFLT.fits'.format(sci_extn)
>>>
>>> save_file = file.replace('_flt.fits', new_root)
>>> save_file = save_file.replace('_flc.fits', new_root)
>>> save_file = save_file.replace('_cmb.fits', new_root)
>>> save_file = save_file.replace('_rate.fits', new_root)
It will also save data to a `~pickle` file:
>>> pkl_file = save_file.replace('.fits', '.pkl')
Parameters
----------
warn : bool
Print a warning and skip if an output file is already found to
exist.
verbose : bool
Print verbose information to the terminal.
Notes
-----
The save filename format was changed May 9, 2017 to the format like
`ib3701ryq.01.GrismFLT.fits` from `ib3701ryq_GrismFLT.fits` to both
allow easier filename parsing and also to allow for instruments that
have multiple `SCI` extensions in a single calibrated file
(e.g., ACS and WFC3/UVIS).
"""
for _flt in self.FLTs:
file = _flt.grism_file
if _flt.grism.data is None:
if warn:
print("{0}: Looks like data already saved!".format(file))
continue
new_root = ".{0:02d}.GrismFLT.fits".format(_flt.grism.sci_extn)
save_file = file.replace("_flt.fits", new_root)
save_file = save_file.replace("_flc.fits", new_root)
save_file = save_file.replace("_cmb.fits", new_root)
save_file = save_file.replace("_rate.fits", new_root)
save_file = save_file.replace("_elec.fits", new_root)
if (save_file == file) & ("GrismFLT" not in file):
# couldn't build new filename based on the extensions
# so just insert at the end
save_file = file.replace(".fits", new_root)
# Rotate back to detector frame if needed
if _flt.grism.instrument in ["NIRISS", "NIRCAM"]:
_flt.transform_JWST_WFSS(verbose=verbose)
# Save the files
print("Save {0}".format(save_file))
_flt.save_full_pickle()
# Reload initialized data
_flt.load_from_fits(save_file)
# Rotate to "GrismFLT" frame if needed
if _flt.grism.instrument in ["NIRISS", "NIRCAM"]:
_flt.transform_JWST_WFSS(verbose=verbose)
[docs] def extend(self, new, verbose=True):
"""
Add another `GroupFLT` instance to `self`
This function appends the exposures if a separate `GroupFLT` instance
to the current instance. You might do this, for example, if you
generate separate `GroupFLT` instances for different grisms and
reference images with different filters.
Parameters
----------
new : `GroupFLT`
The `GroupFLT` instance to add to the current one.
verbose : bool
Print verbose information to the terminal.
"""
import copy
self.FLTs.extend(new.FLTs)
direct_files = copy.copy(self.direct_files)
direct_files.extend(new.direct_files)
self.direct_files = direct_files
grism_files = copy.copy(self.grism_files)
grism_files.extend(new.grism_files)
self.grism_files = grism_files
# self.direct_files.extend(new.direct_files)
# self.grism_files.extend(new.grism_files)
if verbose:
print("Now we have {0:d} FLTs".format(self.N))
[docs] def compute_single_model(
self,
id,
center_rd=None,
mag=-99,
size=-1,
store=False,
spectrum_1d=None,
is_cgs=False,
get_beams=None,
in_place=True,
min_size=26,
psf_param_dict={},
):
"""
Compute model spectrum in all exposures
Parameters
----------
id : int
Catalog ID of the object to extract.
center_rd : optional, (float, float)
Extract based on RA/Dec rather than catalog ID.
mag : float
Magnitude of the object.
size : int
Half-size of the 2D spectrum to extract, along cross-dispersion
axis.
store : bool
Store the model in the `flt` object.
spectrum_1d : list
1D spectrum to fit.
is_cgs : bool
Spectrum is in cgs units.
get_beams : list
List of beam IDs to extract.s
in_place : bool
Compute the model in place.
min_size : int
Minimum size for the 2D spectrum to extract.
psf_param_dict : dict
Dictionary of PSF parameters to pass to the `~grizli.model.Beam`
objects.
Returns
-------
if get_beams:
out_beams : list
List of `~grizli.model.BeamCutout` objects.
else:
status : bool
True if the model was computed successfully
"""
out_beams = []
for flt in self.FLTs:
if flt.grism.parent_file in psf_param_dict:
psf_params = psf_param_dict[flt.grism.parent_file]
else:
psf_params = None
if center_rd is None:
x = y = None
else:
_rd = np.array(center_rd)[None, :]
x, y = flt.direct.wcs.all_world2pix(_rd, 0).flatten()
status = flt.compute_model_orders(
id=id,
x=x,
y=y,
verbose=False,
size=size,
compute_size=(size < 0),
mag=mag,
in_place=in_place,
store=store,
spectrum_1d=spectrum_1d,
is_cgs=is_cgs,
get_beams=get_beams,
psf_params=psf_params,
min_size=min_size,
)
out_beams.append(status)
if get_beams:
return out_beams
else:
return True
[docs] def compute_full_model(
self,
fit_info=None,
verbose=True,
store=False,
mag_limit=25,
coeffs=[1.2, -0.5],
cpu_count=0,
is_cgs=False,
model_kwargs={"compute_size": True},
):
"""
Compute continuum models of all sources in an FLT
Parameters
----------
fit_info : dict
Dictionary of the fitting information.
verbose : bool
Print verbose information to the terminal.
store : bool
Store the model in the `flt` object.
mag_limit : float
Faint limit of objects to compute
coeffs : list
Polynomial coefficients of the continuum model
cpu_count : int
Number of CPUs to use for parallel processing. If 0, then get
from `multiprocessing.cpu_count`.
is_cgs : bool
Spectral models are in cgs units
model_kwargs : dict
Keywords to pass to the
`~grizli.model.GrismFLT.compute_model_orders` method of the
`~grizli.model.GrismFLT` objects.
Returns
-------
Sets `object_dispersers` and `model` attributes on items in
`self.FLTs`
"""
if cpu_count <= 0:
cpu_count = np.maximum(mp.cpu_count() - 4, 1)
if fit_info is None:
bright = self.catalog["MAG_AUTO"] < mag_limit
ids = self.catalog["NUMBER"][bright]
mags = self.catalog["MAG_AUTO"][bright]
# Polynomial component
# xspec = np.arange(0.3, 5.35, 0.05)-1
xspec = np.arange(self.polyx[0], self.polyx[1], 0.05)
if len(self.polyx) > 2:
px0 = self.polyx[2]
else:
px0 = 1.0
yspec = [(xspec - px0) ** o * coeffs[o] for o in range(len(coeffs))]
xspec = (xspec) * 1.0e4
yspec = np.sum(yspec, axis=0)
fit_info = OrderedDict()
for id, mag in zip(ids, mags):
fit_info[id] = {"mag": mag, "spec": [xspec, yspec]}
is_cgs = False
t0_pool = time.time()
# Remove conf
for flt_i in self.FLTs:
if hasattr(flt_i, "conf"):
delattr(flt_i, "conf")
pool = mp.Pool(processes=cpu_count)
jobs = [
pool.apply_async(
_compute_model, (i, self.FLTs[i], fit_info, is_cgs, store, model_kwargs)
)
for i in range(self.N)
]
pool.close()
pool.join()
for res in jobs:
i, model, dispersers = res.get(timeout=1)
self.FLTs[i].object_dispersers = dispersers
self.FLTs[i].model = model
# Reload conf
for flt_i in self.FLTs:
if not hasattr(flt_i, "conf"):
flt_i.conf = grismconf.load_grism_config(flt_i.conf_file)
t1_pool = time.time()
if verbose:
print("Models computed - {0:.2f} sec.".format(t1_pool - t0_pool))
[docs] def get_beams(
self,
id,
size=10,
center_rd=None,
beam_id="A",
min_overlap=0.1,
min_valid_pix=10,
min_mask=0.01,
min_sens=0.08,
mask_resid=True,
get_slice_header=True,
show_exception=False,
):
"""
Extract 2D spectra "beams" from the GroupFLT exposures.
Parameters
----------
id : int
Catalog ID of the object to extract.
size : int
Half-size of the 2D spectrum to extract, along cross-dispersion
axis.
center_rd : optional, (float, float)
Extract based on RA/Dec rather than catalog ID.
beam_id : type
Name of the order to extract.
min_overlap : float
Fraction of the spectrum along wavelength axis that has one
or more valid pixels.
min_valid_pix : int
Minimum number of valid pixels (`beam.fit_mask == True`) in 2D
spectrum.
min_mask : float
Minimum factor relative to the maximum pixel value of the flat
f-lambda model where the 2D cutout data are considered good.
Passed through to `~grizli.model.BeamCutout`.
min_sens : float
See `~grizli.model.BeamCutout`.
mask_resid : bool
Passed to `~grizli.model.BeamCutout`.
get_slice_header : bool
Passed to `~grizli.model.BeamCutout`.
show_exception : bool
Print exceptions when they occur.
Returns
-------
beams : list
List of `~grizli.model.BeamCutout` objects.
"""
beams = self.compute_single_model(
id, center_rd=center_rd, size=size, store=False, get_beams=[beam_id]
)
out_beams = []
for flt, beam in zip(self.FLTs, beams):
try:
out_beam = model.BeamCutout(
flt=flt,
beam=beam[beam_id],
conf=flt.conf,
min_mask=min_mask,
min_sens=min_sens,
mask_resid=mask_resid,
get_slice_header=get_slice_header,
)
except:
# print('Except: get_beams')
if show_exception:
utils.log_exception(utils.LOGFILE, traceback)
continue
valid = out_beam.grism["SCI"] != 0
valid &= out_beam.fit_mask.reshape(out_beam.sh)
hasdata = (valid.sum(axis=0) > 0).sum()
if hasdata * 1.0 / out_beam.model.shape[1] < min_overlap:
continue
# Empty direct image?
if out_beam.beam.total_flux == 0:
continue
if out_beam.fit_mask.sum() < min_valid_pix:
continue
out_beams.append(out_beam)
return out_beams
[docs] def refine_list(
self,
ids=[],
mags=[],
poly_order=3,
mag_limits=[16, 24],
max_coeff=5,
ds9=None,
verbose=True,
fcontam=0.5,
wave=np.linspace(0.2, 2.5e4, 100),
):
"""
Refine contamination model for list of objects. Loops over `refine`.
Parameters
----------
ids : list
List of object IDs
mags : list
Magnitudes to to along with IDs. If `ids` and `mags` not
specified, then get the ID list from `self.catalog['MAG_AUTO']`.
poly_order : int
Order of the polynomial fit to the spectra.
mag_limits : [float, float]
Magnitude limits of objects to fit from `self.catalog['MAG_AUTO']`
when `ids` and `mags` not set.
max_coeff : float
Fit is considered bad when one of the coefficients is greater
than this value. See `refine`.
ds9 : `~grizli.ds9.DS9`, optional
Display the refined models to DS9 as they are computed.
verbose : bool
Print fit coefficients.
fcontam : float
Contamination weighting parameter.
wave : `~numpy.array`
Wavelength array for the polynomial fit.
Returns
-------
Updates `self.model` in place.
"""
if (len(ids) == 0) | (len(ids) != len(mags)):
bright = (self.catalog["MAG_AUTO"] < mag_limits[1]) & (
self.catalog["MAG_AUTO"] > mag_limits[0]
)
ids = self.catalog["NUMBER"][bright] * 1
mags = self.catalog["MAG_AUTO"][bright] * 1
so = np.argsort(mags)
ids, mags = ids[so], mags[so]
# wave = np.linspace(0.2,5.4e4,100)
poly_templates = utils.polynomial_templates(wave, order=poly_order, line=False)
for id, mag in zip(ids, mags):
self.refine(
id,
mag=mag,
poly_order=poly_order,
max_coeff=max_coeff,
size=30,
ds9=ds9,
verbose=verbose,
fcontam=fcontam,
templates=poly_templates,
)
[docs] def refine(
self,
id,
mag=-99,
poly_order=3,
size=30,
ds9=None,
verbose=True,
max_coeff=2.5,
fcontam=0.5,
templates=None,
):
"""
Fit polynomial to extracted spectrum of single object to use for contamination model.
Parameters
----------
id : int
Object ID to extract.
mag : float
Object magnitude. Determines which orders to extract; see
`~grizli.model.GrismFLT.compute_model_orders`.
poly_order : int
Order of the polynomial to fit.
size : int
Size of cutout to extract.
ds9 : `~grizli.ds9.DS9`, optional
Display the refined models to DS9 as they are computed.
verbose : bool
Print information about the fit
max_coeff : float
The script computes the implied flux of the polynomial template
at the pivot wavelength of the direct image filters. If this
flux is greater than `max_coeff` times the *observed* flux in the
direct image, then the polynomal fit is considered bad.
fcontam : float
Contamination weighting parameter.
templates : dict, optional
Precomputed template dictionary. If `None` then compute
polynomial templates with order `poly_order`.
Returns
-------
Updates `self.model` in place.
"""
beams = self.get_beams(
id,
size=size,
min_overlap=0.1,
get_slice_header=False,
min_mask=0.01,
min_sens=0.01,
mask_resid=True,
)
if len(beams) == 0:
return True
mb = MultiBeam(
beams,
fcontam=fcontam,
min_sens=0.01,
sys_err=0.03,
min_mask=0.01,
mask_resid=True,
)
if templates is None:
wave = np.linspace(0.9 * mb.wavef.min(), 1.1 * mb.wavef.max(), 100)
templates = utils.polynomial_templates(wave, order=poly_order, line=False)
try:
tfit = mb.template_at_z(
z=0,
templates=templates,
fit_background=True,
fitter="lstsq",
get_uncertainties=2,
)
except:
ret = False
return False
scale_coeffs = [
tfit["cfit"]["poly {0}".format(i)][0] for i in range(1 + poly_order)
]
xspec, ypoly = tfit["cont1d"].wave, tfit["cont1d"].flux
# Don't extrapolate
mb_waves = mb.wavef[mb.fit_mask]
mb_clip = (xspec > mb_waves.min()) & (xspec < mb_waves.max())
if mb_clip.sum() > 0:
ypoly[xspec < mb_waves.min()] = ypoly[mb_clip][0]
ypoly[xspec > mb_waves.max()] = ypoly[mb_clip][-1]
# Check where templates inconsistent with broad-band fluxes
xb = [
beam.direct.ref_photplam
if beam.direct["REF"] is not None
else beam.direct.photplam
for beam in beams
]
obs_flux = np.array([beam.beam.total_flux for beam in beams])
mod_flux = np.polynomial.Polynomial(scale_coeffs)(np.array(xb) / 1.0e4 - 1)
nonz = obs_flux != 0
if (
(np.abs(mod_flux / obs_flux)[nonz].max() > max_coeff)
| ((~np.isfinite(mod_flux / obs_flux)[nonz]).sum() > 0)
| (np.min(mod_flux[nonz]) < 0)
| ((~np.isfinite(ypoly)).sum() > 0)
):
if verbose:
cstr = " ".join(["{0:9.2e}".format(c) for c in scale_coeffs])
print("{0:>5d} mag={1:6.2f} {2} xx".format(id, mag, cstr))
return True
# Put the refined model into the full-field model
self.compute_single_model(
id,
mag=mag,
size=-1,
store=False,
spectrum_1d=[xspec, ypoly],
is_cgs=True,
get_beams=None,
in_place=True,
)
# Display the result?
if ds9:
flt = self.FLTs[0]
mask = flt.grism["SCI"] != 0
ds9.view((flt.grism["SCI"] - flt.model) * mask, header=flt.grism.header)
if verbose:
cstr = " ".join(["{0:9.2e}".format(c) for c in scale_coeffs])
print("{0:>5d} mag={1:6.2f} {2}".format(id, mag, cstr))
return True
# m2d = mb.reshape_flat(modelf)
# ############
# def old_refine(self, id, mag=-99, poly_order=1, size=30, ds9=None, verbose=True, max_coeff=2.5):
# """TBD
# """
# # Extract and fit beam spectra
# beams = self.get_beams(id, size=size, min_overlap=0.5, get_slice_header=False)
# if len(beams) == 0:
# return True
#
# mb = MultiBeam(beams)
# try:
# A, out_coeffs, chi2, modelf = mb.fit_at_z(poly_order=poly_order, fit_background=True, fitter='lstsq')
# except:
# return False
#
# # Poly template
# scale_coeffs = out_coeffs[mb.N*mb.fit_bg:mb.N*mb.fit_bg+mb.n_poly]
# xspec, yfull = mb.eval_poly_spec(out_coeffs)
#
# # Check where templates inconsistent with broad-band fluxes
# xb = [beam.direct.ref_photplam if beam.direct['REF'] is not None
# else beam.direct.photplam for beam in beams]
# fb = [beam.beam.total_flux for beam in beams]
# mb = np.polynomial.Polynomial(scale_coeffs)(np.array(xb)/1.e4-1)
#
# if (np.abs(mb/fb).max() > max_coeff) | (~np.isfinite(mb/fb).sum() > 0) | (np.min(mb) < 0):
# if verbose:
# print('{0} mag={1:6.2f} {2} xx'.format(id, mag, scale_coeffs))
#
# return True
#
# # Put the refined model into the full-field model
# self.compute_single_model(id, mag=mag, size=-1, store=False, spectrum_1d=[(xspec+1)*1.e4, yfull], is_cgs=True, get_beams=None, in_place=True)
#
# # Display the result?
# if ds9:
# flt = self.FLTs[0]
# mask = flt.grism['SCI'] != 0
# ds9.view((flt.grism['SCI'] - flt.model)*mask,
# header=flt.grism.header)
#
# if verbose:
# print('{0} mag={1:6.2f} {2}'.format(id, mag, scale_coeffs))
#
# return True
# #m2d = mb.reshape_flat(modelf)
[docs] def make_stack(
self,
id,
size=20,
target="grism",
skip=True,
fcontam=1.0,
scale=1,
save=True,
kernel="point",
pixfrac=1,
diff=True,
):
"""
Make drizzled 2D stack for a given object
Parameters
----------
id : int
Object ID number.
size : int
Size of the cutout.
target : str
Rootname for output files.
skip : bool
If True and the stack PNG file already exists, don't proceed.
fcontam : float
Contamination weighting parameter.
scale : float
Drizzle scale.
save : bool
Save the figure and FITS HDU to files with names like
>>> img_file = '{0}_{1:05d}.stack.png'.format(target, id)
>>> fits_file = '{0}_{1:05d}.stack.fits'.format(target, id)
kernel : str
Drizzle kernel.
pixfrac : float
Drizzle `pixfrac` parameter.
diff : bool
Plot residual in final stack panel.
Returns
-------
hdu : `~astropy.io.fits.HDUList`
FITS HDU of the stacked spectra.
fig : `~matplotlib.figure.Figure`
Stack figure object.
"""
print(target, id)
if os.path.exists("{0}_{1:05d}.stack.png".format(target, id)) & skip:
return True
beams = self.get_beams(id, size=size, beam_id="A")
if len(beams) == 0:
print("id = {0}: No beam cutouts available.".format(id))
return None
mb = MultiBeam(beams, fcontam=fcontam, group_name=target)
hdu, fig = mb.drizzle_grisms_and_PAs(
fcontam=fcontam,
flambda=False,
size=size,
scale=scale,
kernel=kernel,
pixfrac=pixfrac,
diff=diff,
)
if save:
fig.savefig("{0}_{1:05d}.stack.png".format(target, id))
hdu.writeto("{0}_{1:05d}.stack.fits".format(target, id), overwrite=True)
return hdu, fig
[docs] def drizzle_grism_models(
self,
root="grism_model",
kernel="square",
scale=0.1,
pixfrac=1,
make_figure=True,
fig_xsize=10,
write_ctx=False
):
"""
Make model-subtracted drizzled images of each grism / PA
Parameters
----------
root : str
Rootname of the output files.
kernel : str
Drizzle kernel e.g., ('square', 'point').
scale : float
Drizzle `scale` parameter, pixel scale in arcsec.
pixfrac : float
Drizzle "pixfrac".
make_figure : bool
Make a figure of the output.
fig_xsize : float
Size of the figure.
write_ctx : bool
Write the context image to the FITS header.
"""
try:
from .utils import drizzle_array_groups
except:
from grizli.utils import drizzle_array_groups
# Loop through grisms and PAs
for g in self.PA:
for pa in self.PA[g]:
idx = self.PA[g][pa]
N = len(idx)
sci_list = []
for i in idx:
grism = self.FLTs[i].grism
if "MED" in grism.data:
sci_list.append(grism["SCI"] + grism["MED"])
else:
sci_list.append(grism["SCI"])
# sci_list = [self.FLTs[i].grism['SCI'] for i in idx]
clean_list = [
self.FLTs[i].grism["SCI"] - self.FLTs[i].model for i in idx
]
wht_list = [
(self.FLTs[i].grism['DQ'] == 0) / self.FLTs[i].grism['ERR']**2
for i in idx
]
for i in range(N):
mask = ~np.isfinite(wht_list[i])
wht_list[i][mask] = 0
wcs_list = [self.FLTs[i].grism.wcs for i in idx]
for i, ix in enumerate(idx):
if wcs_list[i]._naxis[0] == 0:
wcs_list[i]._naxis = self.FLTs[ix].grism.sh[::-1]
# Science array
outfile = "{0}-{1}-{2}_grism_sci.fits".format(root, g.lower(), pa)
print(outfile)
out = drizzle_array_groups(
sci_list,
wht_list,
wcs_list,
scale=scale,
kernel=kernel,
pixfrac=pixfrac,
)
outsci, _, _, outctx, header, outputwcs = out
header["FILTER"] = g
header["PA"] = pa
# Add invidual FLTs to header
for i, ix in enumerate(idx):
header[f'FLT{str(i+1).zfill(5)}'] = self.FLTs[ix].grism_file
pyfits.writeto(
outfile,
data=outsci,
header=header,
overwrite=True,
output_verify="fix",
)
if write_ctx:
pyfits.writeto(outfile.replace('sci', 'ctx'), data=outctx,
header=header, overwrite=True, output_verify='fix')
# Model-subtracted
outfile = "{0}-{1}-{2}_grism_clean.fits".format(root, g.lower(), pa)
print(outfile)
out = drizzle_array_groups(
clean_list,
wht_list,
wcs_list,
scale=scale,
kernel=kernel,
pixfrac=pixfrac,
)
outsci, _, _, header, outputwcs = out
header["FILTER"] = g
header["PA"] = pa
pyfits.writeto(
outfile,
data=outsci,
header=header,
overwrite=True,
output_verify="fix",
)
# Make figure
if make_figure:
with pyfits.open(outfile.replace("clean", "sci")) as img:
im = img[0].data * 1
im[im == 0] = np.nan
sh = im.shape
yp, xp = np.indices(sh)
mask = np.isfinite(im)
xmi = np.maximum(xp[mask].min() - 10, 0)
xma = np.minimum(xp[mask].max() + 10, sh[1])
ymi = np.maximum(yp[mask].min() - 10, 0)
yma = np.minimum(yp[mask].max() + 10, sh[0])
xsl = slice(xmi, xma)
ysl = slice(ymi, yma)
_dy = ysl.stop - ysl.start
_dx = xsl.stop - xsl.start
sh_aspect = _dy / _dx
vmi, vma = -0.05, 0.2
fig = plt.figure(figsize=[fig_xsize, fig_xsize / 2 * sh_aspect])
ax = fig.add_subplot(121)
ax.imshow(
im[ysl, xsl], origin="lower", cmap="gray_r", vmin=vmi, vmax=vma
)
# Clean
ax = fig.add_subplot(122)
with pyfits.open(outfile) as img:
im = img[0].data * 1
im[im == 0] = np.nan
ax.imshow(
im[ysl, xsl], origin="lower", cmap="gray_r", vmin=vmi, vmax=vma
)
for ax in fig.axes:
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.axis("off")
fig.tight_layout(pad=0.0)
fig.text(
0.5,
0.98,
outfile.split("_grism")[0],
color="k",
bbox=dict(facecolor="w", edgecolor="None"),
ha="center",
va="top",
transform=fig.transFigure,
)
fig.savefig(outfile.split("_clean")[0] + ".png", transparent=True)
plt.close(fig)
[docs] def subtract_sep_background(
self,
mask_threshold=3,
mask_iter=3,
bw=256,
bh=256,
fw=3,
fh=3,
revert=True,
**kwargs,
):
"""
Remove a 2D background from grism exposures with `sep.Background`.
Parameters
----------
mask_threshold : float
S/N threshold to mask sources.
mask_iter : int
Number of iterations to mask sources.
bw, bh : int
Background mesh width and height.
fw, fh : int
Filter width and height.
revert : bool
Add the background back to the science array.
"""
import sep
for flt in self.FLTs:
msg = f"subtract_sep_background: {flt.grism.parent_file} "
msg += f" mask_threshold={mask_threshold}"
msg += f" fw,fh={fw},{fh} bw,bh={bw},{bh}"
utils.log_comment(utils.LOGFILE, msg, verbose=True)
sci_i = (flt.grism.data["SCI"] - flt.model).astype(np.float32) * 1
if revert & ("BKG" in flt.grism.data):
sci_i += flt.grism.data["BKG"]
err_i = flt.grism.data["ERR"]
dq_i = flt.grism.data["DQ"]
ok = (sci_i != 0) & (err_i > 0) & np.isfinite(err_i + sci_i)
ok &= (dq_i & 1025) == 0
sci_i[~ok] = 0
ierr = 1 / err_i
ierr[~ok] = 0
bkg = 0
for _iter_i in range(mask_iter):
mask = ~(ok & ((sci_i - bkg) * ierr < mask_threshold))
back = sep.Background(sci_i, mask=mask, bw=bw, bh=bh, fw=fw, fh=fh)
bkg = back.back()
flt.grism.data["SCI"] -= bkg * ok
flt.grism.data["BKG"] = bkg * ok
flt.grism.header["BKGFW"] = fw, "sep background fw"
flt.grism.header["BKGFH"] = fh, "sep background fh"
flt.grism.header["BKGBW"] = bw, "sep background bw"
flt.grism.header["BKGBH"] = bh, "sep background bh"
[docs] def drizzle_full_wavelength(
self,
wave=1.4e4,
ref_header=None,
kernel="point",
pixfrac=1.0,
verbose=True,
offset=[0, 0],
fcontam=0.0,
):
"""
Drizzle FLT frames recentered at a specified wavelength
Script computes polynomial coefficients that define the dx and dy
offsets to a specific dispersed wavelengh relative to the reference
position and adds these to the SIP distortion keywords before
drizzling the input exposures to the output frame.
Parameters
----------
wave : float
Reference wavelength to center the output products
ref_header : `~astropy.io.fits.Header`
Reference header for setting the output WCS and image dimensions.
kernel : str, ('square' or 'point')
Drizzle kernel to use
pixfrac : float
Drizzle PIXFRAC (for `kernel` = 'point')
verbose : bool
Print information to terminal
offset : [float, float]
Offset to apply to the polynomial fit to the SIP coefficients.
fcontam : float
Contamination weighting parameter.
Returns
-------
sci, wht : `~np.ndarray`
Drizzle science and weight arrays with dimensions set in
`ref_header`.
"""
from astropy.modeling import models, fitting
import astropy.wcs as pywcs
# try:
# import drizzle
# if drizzle.__version__ != '1.12.99':
# # Not the fork that works for all input/output arrays
# raise(ImportError)
#
# #print('drizzle!!')
# from drizzle.dodrizzle import dodrizzle
# drizzler = dodrizzle
# dfillval = '0'
# except:
from drizzlepac import adrizzle
adrizzle.log.setLevel("ERROR")
drizzler = adrizzle.do_driz
dfillval = 0
# Quick check now for which grism exposures we should use
if wave < 1.1e4:
use_grism = "G102"
else:
use_grism = "G141"
# Get the configuration file
conf = None
for i in range(self.N):
if self.FLTs[i].grism.filter == use_grism:
conf = self.FLTs[i].conf
# Grism not found in list
if conf is None:
return False
# Compute field-dependent dispersion parameters
dydx_0_p = conf.conf["DYDX_A_0"]
dydx_1_p = conf.conf["DYDX_A_1"]
dldp_0_p = conf.conf["DLDP_A_0"]
dldp_1_p = conf.conf["DLDP_A_1"]
yp, xp = np.indices((1014, 1014)) # hardcoded for WFC3/IR
sk = 10 # don't need to evaluate at every pixel
dydx_0 = conf.field_dependent(xp[::sk, ::sk], yp[::sk, ::sk], dydx_0_p)
dydx_1 = conf.field_dependent(xp[::sk, ::sk], yp[::sk, ::sk], dydx_1_p)
dldp_0 = conf.field_dependent(xp[::sk, ::sk], yp[::sk, ::sk], dldp_0_p)
dldp_1 = conf.field_dependent(xp[::sk, ::sk], yp[::sk, ::sk], dldp_1_p)
# Inverse pixel offsets from the specified wavelength
dp = (wave - dldp_0) / dldp_1
i_x, i_y = 1, 0 # indexing offsets
dx = dp / np.sqrt(1 + dydx_1) + i_x
dy = dydx_0 + dydx_1 * dx + i_y
dx += offset[0]
dy += offset[1]
# Compute polynomial coefficients
p_init = models.Polynomial2D(degree=4)
# fit_p = fitting.LevMarLSQFitter()
fit_p = fitting.LinearLSQFitter()
p_dx = fit_p(p_init, xp[::sk, ::sk] - 507, yp[::sk, ::sk] - 507, -dx)
p_dy = fit_p(p_init, xp[::sk, ::sk] - 507, yp[::sk, ::sk] - 507, -dy)
# Output WCS
out_wcs = pywcs.WCS(ref_header, relax=True)
out_wcs.pscale = utils.get_wcs_pscale(out_wcs)
# Initialize outputs
shape = (ref_header["NAXIS2"], ref_header["NAXIS1"])
outsci = np.zeros(shape, dtype=np.float32)
outwht = np.zeros(shape, dtype=np.float32)
outctx = np.zeros(shape, dtype=np.int32)
# Loop through exposures
for i in range(self.N):
flt = self.FLTs[i]
if flt.grism.filter != use_grism:
continue
h = flt.grism.header.copy()
# Update SIP coefficients
for j, p in enumerate(p_dx.param_names):
key = "A_" + p[1:]
if key in h:
h[key] += p_dx.parameters[j]
else:
h[key] = p_dx.parameters[j]
for j, p in enumerate(p_dy.param_names):
key = "B_" + p[1:]
if key in h:
h[key] += p_dy.parameters[j]
else:
h[key] = p_dy.parameters[j]
line_wcs = pywcs.WCS(h, relax=True)
line_wcs.pscale = utils.get_wcs_pscale(line_wcs)
if not hasattr(line_wcs, "pixel_shape"):
line_wcs.pixel_shape = line_wcs._naxis1, line_wcs._naxis2
# Science and wht arrays
sci = flt.grism["SCI"] - flt.model
wht = 1 / (flt.grism["ERR"] ** 2)
scl = np.exp(-(fcontam * np.abs(flt.model) / flt.grism["ERR"]))
wht *= scl
wht[~np.isfinite(wht)] = 0
# Drizzle it
if verbose:
print(
"Drizzle {0} to wavelength {1:.2f}".format(
flt.grism.parent_file, wave
)
)
drizzler(
sci,
line_wcs,
wht,
out_wcs,
outsci,
outwht,
outctx,
1.0,
"cps",
1,
wcslin_pscale=line_wcs.pscale,
uniqid=1,
pixfrac=pixfrac,
kernel=kernel,
fillval=dfillval,
)
# Done!
return outsci, outwht
[docs] def find_source_along_trace(self, line_ra, line_dec):
"""
Given a sky position in the (artificial) dispersed image frame,
find sources that would disperse to that position.
Parameters
----------
line_ra, line_dec : float
RA, Dec of the source to find.
Returns
-------
None : TBD
"""
# Find an exposure that contains the position
conf = None
for flt in self.FLTs:
wcs_i = flt.grism.wcs
xw, yw = np.squeeze(wcs_i.all_world2pix([line_ra], [line_dec], 0))
sh = flt.grism.sh
if (xw > 0) & (yw > 0) & (xw < sh[1]) & (yw < sh[0]):
conf = flt.conf
break
if conf is None:
msg = f"No exposures disperse to {line_ra}, {line_dec}"
print(msg)
return None
# TBD
return None
# def replace_direct_image_cutouts(beams_file='', ref_image='gdn-100mas-f160w_drz_sci.fits', interp='poly5', cutout=200, background_func=utils.mode_statistic):
# """
# Replace "REF" extensions in a `beams.fits` file
#
# Parameters
# ----------
# beams_file : str
# Filename of a "beams.fits" file.
#
# ref_image : str or `~astropy.io.fits.HDUList`
# Filename or preloaded FITS file.
#
# interp : str
# Interpolation function to use for `~drizzlepac.astrodrizzle.ablot.do_blot`.
#
# cutout : int
# Make a slice of the `ref_image` with size [-cutout,+cutout] around
# the center position of the desired object before passing to `blot`.
#
# Returns
# -------
# beams_image : `~astropy.io.fits.HDUList`
# Image object with the "REF" extensions filled with the new blotted
# image cutouts.
#
# """
# from drizzlepac.astrodrizzle import ablot
#
# if isinstance(ref_image, pyfits.HDUList):
# ref_im = ref_image
# ref_image_filename = ref_image.filename()
# else:
# ref_im = pyfits.open(ref_image)
# ref_image_filename = ref_image
#
# ref_wcs = pywcs.WCS(ref_im[0].header, relax=True)
# ref_wcs.pscale = utils.get_wcs_pscale(ref_wcs)
#
# ref_photflam = ref_im[0].header['PHOTFLAM']
# ref_data = ref_im[0].data
# dummy_wht = np.ones_like(ref_im[0].data, dtype=np.float32)
#
# beams_image = pyfits.open(beams_file)
#
# beam_ra = beams_image[0].header['RA']
# beam_dec = beams_image[0].header['DEC']
#
# xy = np.asarray(np.round(ref_wcs.all_world2pix([beam_ra], [beam_dec], 0)),dtype=int).flatten()
#
# slx = slice(xy[0]-cutout, xy[0]+cutout)
# sly = slice(xy[1]-cutout, xy[1]+cutout)
#
# bkg_data = []
#
# for ie, ext in enumerate(beams_image):
# if 'EXTNAME' not in ext.header:
# continue
# elif ext.header['EXTNAME'] == 'REF':
# #break
#
# ext.header['REF_FILE'] = ref_image_filename
# for k in ['PHOTFLAM', 'PHOTPLAM']:
# ext.header[k] = ref_im[0].header[k]
#
# the_filter = utils.parse_filter_from_header(ref_im[0].header)
# ext.header['FILTER'] = ext.header['DFILTER'] = the_filter
#
# wcs_file = ext.header['GPARENT'].replace('.fits', '.{0:02}.wcs.fits'.format(ext.header['SCI_EXTN']))
# if os.path.exists(wcs_file):
# wcs_fobj = pyfits.open(wcs_file)
#
# ext_wcs = pywcs.WCS(ext.header, relax=True,
# fobj=wcs_fobj)
# # ext_wcs.pixel_shape = (wcs_fobj[0].header['CRPIX1']*2,
# # wcs_fobj[0].header['CRPIX2']*2)
# # try:
# # ext_wcs.wcs.cd = ext_wcs.wcs.pc
# # delattr(ext_wcs.wcs, 'pc')
# # except:
# # pass
# else:
# ext_wcs = pywcs.WCS(ext.header, relax=True)
#
# ext_wcs.pscale = utils.get_wcs_pscale(ext_wcs)
# blotted = ablot.do_blot(ref_data[sly, slx],
# ref_wcs.slice([sly, slx]),
# ext_wcs, 1, coeffs=True, interp=interp,
# sinscl=1.0, stepsize=10, wcsmap=None)
#
# if background_func is not None:
# seg_data = beams_image[ie+1].data
# msk = seg_data == 0
# #print(msk.shape, blotted.shape, seg_data.shape, ie)
# if msk.sum() > 0:
# if bkg_data is None:
# bkg_data = blotted[msk]
# else:
# bkg_data = np.append(bkg_data, blotted[msk])
#
# if msk.sum() > 0:
# blotted -= background_func(blotted[msk])
#
# ext.data = blotted*ref_photflam
#
# if bkg_data is not None:
# bkg_value = background_func(bkg_data)
# for i in range(self.N):
#
# return beams_image
[docs]class MultiBeam(GroupFitter):
def __init__(
self,
beams,
group_name=None,
fcontam=0.0,
psf=False,
polyx=[0.3, 2.5],
MW_EBV=0.0,
min_mask=0.01,
min_sens=0.08,
sys_err=0.0,
mask_resid=True,
verbose=True,
replace_direct=None,
restore_medfilt=False,
**kwargs,
):
"""
Tools for dealing with multiple `~.model.BeamCutout` instances
Parameters
----------
beams : list
List of `~.model.BeamCutout` objects.
group_name : str, None
Rootname to use for saved products. If None, then default to
'group'.
fcontam : float
Factor to use to downweight contaminated pixels. The pixel
inverse variances are scaled by the following weight factor when
evaluating chi-squared of a 2D fit,
`weight = np.exp(-(fcontam*np.abs(contam)*np.sqrt(ivar)))`
where `contam` is the contaminating flux and `ivar` is the initial
pixel inverse variance.
psf : bool
Fit an ePSF model to the direct image to use as the morphological
reference.
polyx : list
Polynomial coefficients to use for the
`~.model.BeamCutout.compute_model` method.
MW_EBV : float
Milky way foreground extinction.
min_mask : float
Minimum factor relative to the maximum pixel value of the flat
f-lambda model where the 2D cutout data are considered good.
Passed through to `~grizli.model.BeamCutout`.
min_sens : float
See `~grizli.model.BeamCutout`.
sys_err : float
Systematic error added in quadrature to the pixel variances:
`var_total = var_initial + (beam.sci*sys_err)**2`
mask_resid : bool
Mask residuals above a certain threshold.
verbose : bool
Print verbose information.
replace_direct : dict
Replace the direct image cutouts with a new image. Keys are
`ref_image`, `interp`, `cutout`, and `background_func`.
restore_medfilt : bool
Restore median-filtered direct image cutouts.
Attributes
----------
Asave : dict
Dictionary to store the results of the fit.
"""
if group_name is None:
self.group_name = "group"
else:
self.group_name = group_name
self.fcontam = fcontam
self.polyx = polyx
self.min_mask = min_mask
self.min_sens = min_sens
self.mask_resid = mask_resid
self.sys_err = sys_err
self.restore_medfilt = restore_medfilt
self.Asave = {}
if isinstance(beams, str):
self.load_master_fits(beams, verbose=verbose)
# check if isJWST
isJWST = prep.check_isJWST(beams)
# Auto-generate group_name from filename, e.g.,
# j100140p0130_00237.beams.fits > j100140p0130
if group_name is None:
self.group_name = beams.split("_")[0]
else:
if isinstance(beams[0], str):
# `beams` is list of strings
if "beams.fits" in beams[0]:
# Master beam files
# check if isJWST
isJWST = prep.check_isJWST(beams[0])
self.load_master_fits(beams[0], verbose=verbose)
for i in range(1, len(beams)):
b_i = MultiBeam(
beams[i],
group_name=group_name,
fcontam=fcontam,
psf=psf,
polyx=polyx,
MW_EBV=np.maximum(MW_EBV, 0),
sys_err=sys_err,
verbose=verbose,
min_mask=min_mask,
min_sens=min_sens,
mask_resid=mask_resid,
)
self.extend(b_i)
else:
# List of individual beam.fits files
self.load_beam_fits(beams)
else:
self.beams = beams
self.ra, self.dec = self.beams[0].get_sky_coords()
if MW_EBV < 0:
# Try to get MW_EBV from mastquery.utils
try:
import mastquery.utils
MW_EBV = mastquery.utils.get_mw_dust(self.ra, self.dec)
except:
try:
import mastquery.utils
MW_EBV = mastquery.utils.get_irsa_dust(self.ra, self.dec)
except:
MW_EBV = 0.0
self.MW_EBV = MW_EBV
self._set_MW_EBV(MW_EBV)
self._parse_beams(psf=psf)
self.apply_trace_shift()
self.Nphot = 0
self.is_spec = 1
if replace_direct is not None:
self.replace_direct_image_cutouts(**replace_direct)
def _set_MW_EBV(self, MW_EBV, R_V=utils.MW_RV):
"""
Initialize Galactic extinction
Parameters
----------
MW_EBV : float
Local E(B-V)
R_V : float
Relation between specific and total extinction,
``a_v = r_v * ebv``.
"""
for b in self.beams:
beam = b.beam
if beam.MW_EBV != MW_EBV:
beam.MW_EBV = MW_EBV
beam.init_galactic_extinction(MW_EBV, R_V=R_V)
beam.process_config()
b.flat_flam = b.compute_model(in_place=False, is_cgs=True)
@property
def N(self):
"""
Length of the beam list.
"""
return len(self.beams)
@property
def Ngrism(self):
"""
Dictionary containing number of exposures by grism.
"""
# Parse grisms & PAs
Ngrism = {}
for beam in self.beams:
if beam.grism.instrument == "NIRISS":
grism = beam.grism.pupil
else:
grism = beam.grism.filter
if grism not in Ngrism:
Ngrism[grism] = 0
Ngrism[grism] += 1
return Ngrism
@property
def grisms(self):
"""
Available grisms
"""
grisms = list(self.Ngrism.keys())
return grisms
@property
def PA(self):
"""
Available PAs in each grism
"""
_PA = {}
for g in self.Ngrism:
_PA[g] = {}
for i, beam in enumerate(self.beams):
if beam.grism.instrument == "NIRISS":
grism = beam.grism.pupil
else:
grism = beam.grism.filter
PA_i = beam.get_dispersion_PA(decimals=0)
if PA_i in _PA[grism]:
_PA[grism][PA_i].append(i)
else:
_PA[grism][PA_i] = [i]
return _PA
@property
def id(self):
"""
ID of the first beam in the list.
"""
return self.beams[0].id
def _parse_beams(self, psf=False):
"""
Derive properties of the beam list (grism, PA) and initialize
data arrays.
Parameters
----------
psf : bool
Use a PSF model for the fit.
"""
# Use WFC3 ePSF for the fit
self.psf_param_dict = None
if (psf > 0) & (self.beams[0].grism.instrument in ["WFC3", "ACS"]):
self.psf_param_dict = OrderedDict()
for ib, beam in enumerate(self.beams):
if beam.direct.data["REF"] is not None:
# Use REF extension. scale factors might be wrong
beam.direct.data["SCI"] = beam.direct.data["REF"]
new_err = np.ones_like(beam.direct.data["ERR"])
new_err *= utils.nmad(beam.direct.data["SCI"])
beam.direct.data["ERR"] = new_err
beam.direct.filter = beam.direct.ref_filter # 'F160W'
beam.direct.photflam = beam.direct.ref_photflam
beam.init_epsf(yoff=0.0, skip=psf * 1, N=4, get_extended=True)
# beam.compute_model = beam.compute_model_psf
# beam.beam.compute_model = beam.beam.compute_model_psf
beam.compute_model(use_psf=True)
m = beam.compute_model(in_place=False)
# beam.modelf = beam.model.flatten()
# beam.model = beam.modelf.reshape(beam.beam.sh_beam)
beam.flat_flam = beam.compute_model(in_place=False, is_cgs=True)
_p = beam.grism.parent_file
self.psf_param_dict[_p] = beam.beam.psf_params
self._parse_beam_arrays()
def _parse_beam_arrays(self):
"""
Parse and initialize various beam-related arrays and attributes.
"""
self.poly_order = None
self.shapes = [beam.model.shape for beam in self.beams]
self.Nflat = [np.prod(shape) for shape in self.shapes]
self.Ntot = np.sum(self.Nflat)
for b in self.beams:
if hasattr(b, "xp_mask"):
delattr(b, "xp_mask")
# Big array of normalized wavelengths (wave / 1.e4 - 1)
self.xpf = (
np.hstack(
[
np.dot(
np.ones((b.beam.sh_beam[0], 1)), b.beam.lam[None, :]
).flatten()
/ 1.0e4
for b in self.beams
]
)
- 1
)
# Flat-flambda model spectra
self.flat_flam = np.hstack([b.flat_flam for b in self.beams])
self.fit_mask = np.hstack([b.fit_mask * b.contam_mask for b in self.beams])
self.DoF = self.fit_mask.sum()
# systematic error
for i, b in enumerate(self.beams):
if hasattr(b, "has_sys_err"):
continue
sciu = b.scif.reshape(b.sh)
ivar = 1.0 / (1 / b.ivar + (self.sys_err * sciu) ** 2)
ivar[~np.isfinite(ivar)] = 0
b.ivar = ivar * 1
b.ivarf = b.ivar.flatten()
self.ivarf = np.hstack([b.ivarf for b in self.beams])
self.fit_mask &= self.ivarf >= 0
self.scif = np.hstack([b.scif for b in self.beams])
self.idf = np.hstack([b.scif * 0 + ib for ib, b in enumerate(self.beams)])
self.idf = np.asarray(self.idf, dtype=int)
# self.ivarf = 1./(1/self.ivarf + (self.sys_err*self.scif)**2)
self.ivarf[~np.isfinite(self.ivarf)] = 0
self.sivarf = np.sqrt(self.ivarf)
self.wavef = np.hstack([b.wavef for b in self.beams])
self.contamf = np.hstack([b.contam.flatten() for b in self.beams])
weightf = np.exp(-(self.fcontam * np.abs(self.contamf) * self.sivarf))
weightf[~np.isfinite(weightf)] = 0
self.weightf = weightf
self.fit_mask &= self.weightf > 0
self.slices = self._get_slices(masked=False)
self._update_beam_mask()
self.DoF = int((self.weightf * self.fit_mask).sum())
self.Nmask = np.sum([b.fit_mask.sum() for b in self.beams])
# Initialize background fit array
# self.A_bg = np.zeros((self.N, self.Ntot))
# i0 = 0
# for i in range(self.N):
# self.A_bg[i, i0:i0+self.Nflat[i]] = 1.
# i0 += self.Nflat[i]
self.A_bg = self._init_background(masked=False)
self.Asave = {}
self.A_bgm = self._init_background(masked=True)
self.init_poly_coeffs(poly_order=1)
self.ra, self.dec = self.beams[0].get_sky_coords()
[docs] def compute_exptime(self):
"""
Compute number of exposures and total exposure time for each grism.
"""
exptime = {}
nexposures = {}
for beam in self.beams:
if beam.grism.instrument == "NIRISS":
grism = beam.grism.pupil
else:
grism = beam.grism.filter
if grism in exptime:
exptime[grism] += beam.grism.exptime
nexposures[grism] += 1
else:
exptime[grism] = beam.grism.exptime
nexposures[grism] = 1
return nexposures, exptime
[docs] def extend(self, new, verbose=True):
"""
Concatenate `~grizli.multifit.MultiBeam` objects
Parameters
----------
new : `~grizli.multifit.MultiBeam`
Beam object containing new beams to add.
verbose : bool
Print summary of the change.
"""
self.beams.extend(new.beams)
self._parse_beams()
if verbose:
print("Add beams: {0}\n Now: {1}".format(new.Ngrism, self.Ngrism))
[docs] def write_master_fits(
self,
verbose=True,
get_hdu=False,
strip=True,
include_model=False,
get_trace_table=True,
):
"""
Store all beams in a single HDU
Parameters
----------
verbose : bool
Print the name of the output file.
get_hdu : bool
Return the `~astropy.io.fits.HDUList` object rather than writing
to a file.
strip : bool
Strip the `~astropy.io.fits.HDUList` object before writing.
include_model : bool
Include the model arrays in the output file.
get_trace_table : bool
Include the trace table in the output file.
Returns
-------
hdu : `~astropy.io.fits.HDUList`
Output HDUList object if `get_hdu=True`.
"""
hdu = pyfits.HDUList([pyfits.PrimaryHDU()])
rd = self.beams[0].get_sky_coords()
hdu[0].header["ID"] = (self.id, "Object ID")
hdu[0].header["RA"] = (rd[0], "Right Ascension")
hdu[0].header["DEC"] = (rd[1], "Declination")
exptime = {}
for g in self.Ngrism:
exptime[g] = 0.0
count = []
for ib, beam in enumerate(self.beams):
hdu_i = beam.write_fits(
get_hdu=True,
strip=strip,
include_model=include_model,
get_trace_table=get_trace_table,
)
hdu.extend(hdu_i[1:])
count.append(len(hdu_i) - 1)
hdu[0].header["FILE{0:04d}".format(ib)] = (
beam.grism.parent_file,
"Grism parent file",
)
hdu[0].header["GRIS{0:04d}".format(ib)] = (
beam.grism.filter,
"Grism element",
)
hdu[0].header["NEXT{0:04d}".format(ib)] = (
count[-1],
"Number of extensions",
)
try:
exptime[beam.grism.filter] += beam.grism.header["EXPTIME"]
except:
exptime[beam.grism.pupil] += beam.grism.header["EXPTIME"]
hdu[0].header["COUNT"] = (self.N, " ".join(["{0}".format(c) for c in count]))
for g in self.Ngrism:
hdu[0].header["T_{0}".format(g)] = (
exptime[g],
"Exposure time in grism {0}".format(g),
)
if get_hdu:
return hdu
outfile = "{0}_{1:05d}.beams.fits".format(self.group_name, self.id)
if verbose:
print(outfile)
hdu.writeto(outfile, overwrite=True)
[docs] def load_master_fits(self, beam_file, verbose=True):
"""
Load a "beams.fits" file.
Parameters
----------
beam_file : str
Filename of the "beams.fits" file.
verbose : bool
Print the name of the file being loaded.
"""
import copy
try:
utils.fetch_acs_wcs_files(beam_file)
except:
pass
if verbose:
print("load_master_fits: {0}".format(beam_file))
# check if isJWST
isJWST = prep.check_isJWST(beam_file)
hdu = pyfits.open(beam_file, lazy_load_hdus=False)
N = hdu[0].header["COUNT"]
Next = np.asarray(hdu[0].header.comments["COUNT"].split(), dtype=int)
i0 = 1
self.beams = []
for i in range(N):
key = "NEXT{0:04d}".format(i)
if key in hdu[0].header:
Next_i = hdu[0].header[key]
else:
Next_i = 6 # Assume doesn't have direct SCI/ERR cutouts
# Testing for multiprocessing
if True:
hducopy = hdu[i0 : i0 + Next_i]
else:
# print('Copy!')
hducopy = pyfits.HDUList(
[
hdu[i].__class__(
data=hdu[i].data * 1,
header=copy.deepcopy(hdu[i].header),
name=hdu[i].name,
)
for i in range(i0, i0 + Next_i)
]
)
beam = model.BeamCutout(
fits_file=hducopy,
min_mask=self.min_mask,
min_sens=self.min_sens,
mask_resid=self.mask_resid,
restore_medfilt=self.restore_medfilt,
)
self.beams.append(beam)
if verbose:
print(
"{0} {1} {2}".format(
i + 1, beam.grism.parent_file, beam.grism.filter
)
)
i0 += Next_i # 6#Next[i]
hdu.close()
[docs] def write_beam_fits(self, verbose=True):
"""
Write individual beam cutouts to FITS files.
Parameters
----------
verbose : bool
Print the name of the output file.
Returns
-------
outfiles : list
List of output filenames.
"""
outfiles = []
for beam in self.beams:
root = beam.grism.parent_file.split(".fits")[0]
outfile = beam.write_fits(root)
if verbose:
print("Wrote {0}".format(outfile))
outfiles.append(outfile)
return outfiles
[docs] def load_beam_fits(self, beam_list, conf=None, verbose=True):
"""
Load a list of individual beam cutouts.
Parameters
----------
beam_list : list
List of filenames to load.
conf : `~grizli.multifit.GroupFitter`
Configuration object to use for the beam cutouts.
verbose : bool
Print the name of the file being loaded.
"""
self.beams = []
for file in beam_list:
if verbose:
print(file)
beam = model.BeamCutout(
fits_file=file,
conf=conf,
min_mask=self.min_mask,
min_sens=self.min_sens,
mask_resid=self.mask_resid,
)
self.beams.append(beam)
[docs] def replace_segmentation_image_cutouts(self, ref_image="gdn-100mas-f160w_seg.fits"):
"""
Replace "REF" extensions in a `beams.fits` file
Parameters
----------
ref_image : str, `~astropy.io.fits.HDUList`, `~astropy.io.fits.ImageHDU`
Filename or preloaded FITS file.
Returns
-------
beams_image : `~astropy.io.fits.HDUList`
Image object with the "REF" extensions filled with the new blotted
image cutouts.
"""
if isinstance(ref_image, pyfits.HDUList):
ref_data = ref_image[0].data
ref_header = ref_image[0].header
ref_image_filename = ref_image.filename()
elif isinstance(ref_image, pyfits.ImageHDU) | isinstance(
ref_image, pyfits.PrimaryHDU
):
ref_data = ref_image.data
ref_header = ref_image.header
ref_image_filename = "HDU"
else:
with pyfits.open(ref_image) as ref_im:
ref_data = ref_im[0].data * 1
ref_header = ref_im[0].header.copy()
ref_image_filename = ref_image
ref_wcs = pywcs.WCS(ref_header, relax=True)
ref_wcs.pscale = utils.get_wcs_pscale(ref_wcs)
ref_data = ref_data.astype(np.float32)
for ib in range(self.N):
wcs_copy = self.beams[ib].direct.wcs
if hasattr(wcs_copy, "idcscale"):
if wcs_copy.idcscale is None:
delattr(wcs_copy, "idcscale")
in_data, in_wcs, out_wcs = ref_data, ref_wcs, wcs_copy
blot_seg = utils.blot_nearest_exact(
ref_data,
ref_wcs,
wcs_copy,
verbose=True,
stepsize=-1,
scale_by_pixel_area=False,
)
self.beams[ib].beam.set_segmentation(blot_seg)
[docs] def replace_direct_image_cutouts(
self,
ref_image="gdn-100mas-f160w_drz_sci.fits",
ext=0,
interp="poly5",
cutout=200,
background_func=np.median,
thumb_labels=None,
):
"""
Replace "REF" extensions in a `beams.fits` file
Parameters
----------
ref_image : str or `~astropy.io.fits.HDUList`
Filename or preloaded FITS file.
ext : int
Extension of the `ref_image` to use.
interp : str
Interpolation function to use for `~drizzlepac.astrodrizzle.ablot.do_blot`.
cutout : int
Make a slice of the `ref_image` with size [-cutout,+cutout] around
the center position of the desired object before passing to
`blot`.
background_func : function, None
If not `None`, compute local background with value from
`background_func(ref_image[cutout])`.
thumb_labels : list
List of thumbnail labels to replace.
Returns
-------
beams_image : `~astropy.io.fits.HDUList`
Image object with the "REF" extensions filled with the new blotted
image cutouts.
"""
from drizzlepac.astrodrizzle import ablot
if isinstance(ref_image, pyfits.HDUList):
ref_data = ref_image[0].data
ref_header = ref_image[0].header
ref_image_filename = ref_image.filename()
elif isinstance(ref_image, pyfits.ImageHDU) | isinstance(
ref_image, pyfits.PrimaryHDU
):
ref_data = ref_image.data
ref_header = ref_image.header
ref_image_filename = "HDU"
else:
with pyfits.open(ref_image)[ext] as ref_im:
ref_data = ref_im.data * 1
ref_header = ref_im.header.copy()
ref_image_filename = ref_image
if ref_data.dtype not in [np.float32, np.dtype(">f4")]:
ref_data = ref_data.astype(np.float32)
ref_wcs = pywcs.WCS(ref_header, relax=True)
ref_wcs.pscale = utils.get_wcs_pscale(ref_wcs)
if not hasattr(ref_wcs, "_naxis1") & hasattr(ref_wcs, "_naxis"):
ref_wcs._naxis1, ref_wcs._naxis2 = ref_wcs._naxis
if "PHOTPLAM" in ref_header:
ref_photplam = ref_header["PHOTPLAM"]
else:
ref_photplam = 1.0
if "PHOTFLAM" in ref_header:
ref_photflam = ref_header["PHOTFLAM"]
else:
ref_photflam = 1.0
try:
ref_filter = utils.parse_filter_from_header(ref_header)
except:
ref_filter = "N/A"
beam_ra, beam_dec = self.ra, self.dec
xy = np.asarray(
np.round(ref_wcs.all_world2pix([beam_ra], [beam_dec], 0)), dtype=int
).flatten()
sh = ref_data.shape
slx = slice(np.maximum(xy[0] - cutout, 0), np.minimum(xy[0] + cutout, sh[1]))
sly = slice(np.maximum(xy[1] - cutout, 0), np.minimum(xy[1] + cutout, sh[0]))
bkg_data = None
for ie in range(self.N):
wcs_copy = self.beams[ie].direct.wcs
if hasattr(wcs_copy, "idcscale"):
if wcs_copy.idcscale is None:
delattr(wcs_copy, "idcscale")
if not hasattr(wcs_copy, "_naxis1") & hasattr(wcs_copy, "_naxis"):
wcs_copy._naxis1, wcs_copy._naxis2 = wcs_copy._naxis
blotted = ablot.do_blot(
ref_data[sly, slx],
ref_wcs.slice([sly, slx]),
wcs_copy,
1,
coeffs=True,
interp=interp,
sinscl=1.0,
stepsize=10,
wcsmap=None,
)
if background_func is not None:
msk = self.beams[ie].beam.seg == 0
# print(msk.shape, blotted.shape, ie)
if msk.sum() > 0:
if bkg_data is None:
bkg_data = blotted[msk]
else:
bkg_data = np.append(bkg_data, blotted[msk])
if thumb_labels is None:
self.beams[ie].direct.data["REF"] = blotted * ref_photflam
self.beams[ie].direct.ref_photflam = ref_photflam
self.beams[ie].direct.ref_photplam = ref_photplam
self.beams[ie].direct.ref_filter = ref_filter
# self.beams[ie].direct.ref_photflam
self.beams[ie].beam.direct = blotted * ref_photflam
else:
for label in thumb_labels:
self.beams[ie].thumbs[label] = blotted * ref_photflam
if bkg_data is not None:
for ie in range(self.N):
bkg_value = background_func(bkg_data) * ref_photflam
if thumb_labels is None:
self.beams[ie].direct.data["REF"] -= bkg_value
else:
for label in thumb_labels:
self.beams[ie].thumbs[label] -= bkg_value
## Recompute total_flux attribute
for b in self.beams:
b.beam.set_segmentation(b.beam.seg)
[docs] def reshape_flat(self, flat_array):
"""
Reshape a flattened array into a list of 2D arrays.
Parameters
----------
flat_array : array-like
Flattened array to reshape.
Returns
-------
out : list
List of 2D arrays.
"""
out = []
i0 = 0
for ib in range(self.N):
im2d = flat_array[i0 : i0 + self.Nflat[ib]].reshape(self.shapes[ib])
out.append(im2d)
i0 += self.Nflat[ib]
return out
[docs] def init_poly_coeffs(self, flat=None, poly_order=1):
"""
Initialize polynomial coefficients for the continuum fit.
Parameters
----------
flat : array-like
Flattened array of the flat f-lambda model.
poly_order : int
Order of the polynomial to fit.
Returns
-------
None (updates `self.A_poly`, `self.n_poly`, `self.x_poly`).
"""
# Already done?
if poly_order < 0:
ok_poly = False
poly_order = 0
else:
ok_poly = True
if poly_order == self.poly_order:
return None
self.poly_order = poly_order
if flat is None:
flat = self.flat_flam
# Polynomial continuum arrays
self.A_poly = np.array(
[self.xpf ** order * flat for order in range(poly_order + 1)]
)
self.A_poly *= ok_poly
self.n_poly = poly_order + 1
self.x_poly = np.array(
[
(self.beams[0].beam.lam / 1.0e4 - 1) ** order
for order in range(poly_order + 1)
]
)
[docs] def eval_poly_spec(self, coeffs_full):
"""
Evaluate polynomial spectrum.
Parameters
----------
coeffs_full : array-like
Full list of polynomial coefficients.
Returns
-------
xspec, yfull : array-like
Arrays of wavelength and polynomial spectrum.
"""
xspec = np.arange(self.polyx[0], self.polyx[1], 0.05)
if len(self.polyx) > 2:
px0 = self.polyx[2]
else:
px0 = 1.0
i0 = self.N * self.fit_bg
scale_coeffs = coeffs_full[i0 : i0 + self.n_poly]
# yspec = [xspec**o*scale_coeffs[o] for o in range(self.poly_order+1)]
yfull = np.polynomial.Polynomial(scale_coeffs)(xspec - px0)
return xspec, yfull
[docs] def compute_model(
self,
id=None,
spectrum_1d=None,
is_cgs=False,
apply_sensitivity=True,
scale=None,
reset=True,
):
"""
Compute the dispersed 2D model for an assumed input spectrum
This is a wrapper around the
`grizli.model.GrismDisperser.compute_model` method.
Nothing returned, but the `model` and `modelf` attributes are
updated on the `~grizli.model.GrismDisperser` subcomponents of the
`beams` list.
Parameters
----------
id : str
Object ID.
spectrum_1d : array-like
1D spectrum to use for the model.
is_cgs : bool
Spectrum is in cgs units.
apply_sensitivity : bool
Apply the sensitivity function to the model.
scale : float
Scale factor for the model.
reset : bool
Reset the model to zero before computing the new model.
"""
for beam in self.beams:
beam.beam.compute_model(
id=id,
spectrum_1d=spectrum_1d,
is_cgs=is_cgs,
scale=scale,
reset=reset,
apply_sensitivity=apply_sensitivity,
)
beam.modelf = beam.beam.modelf
beam.model = beam.beam.modelf.reshape(beam.beam.sh_beam)
[docs] def compute_model_psf(self, id=None, spectrum_1d=None, is_cgs=False):
"""
Compute the dispersed 2D model for an assumed input spectrum and for
ePSF morphologies
This is a wrapper around the
`grizli.model.GrismDisperser.compute_model_psf` method.
Nothing returned, but the `model` and `modelf` attributes are
updated on the `~grizli.model.GrismDisperser` subcomponents of the
`beams` list.
Parameters
----------
id : str
Object ID.
spectrum_1d : array-like
1D spectrum to use for the model.
is_cgs : bool
Spectrum is in cgs units.
"""
for beam in self.beams:
beam.beam.compute_model_psf(id=id, spectrum_1d=spectrum_1d, is_cgs=is_cgs)
beam.modelf = beam.beam.modelf
beam.model = beam.beam.modelf.reshape(beam.beam.sh_beam)
[docs] def fit_at_z(
self, z=0.0, templates={}, fitter="nnls", fit_background=True, poly_order=0
):
"""
Fit the model at a given redshift.
Parameters
----------
z : float
Redshift to fit.
templates : dict
Dictionary of template spectra to fit.
fitter : str
Fitting method. Options are 'nnls' and 'lstsq'.
fit_background : bool
Fit a constant background.
poly_order : int
Order of the polynomial to fit.
Returns
-------
out_coeffs : array-like
Output coefficients of the fit.
"""
try:
import sklearn.linear_model
HAS_SKLEARN = True
except:
HAS_SKLEARN = False
import numpy.linalg
import scipy.optimize
# print 'xxx Init poly'
self.init_poly_coeffs(poly_order=poly_order)
# print 'xxx Init bg'
if fit_background:
self.fit_bg = True
A = np.vstack((self.A_bg, self.A_poly))
else:
self.fit_bg = False
A = self.A_poly * 1
NTEMP = len(templates)
A_temp = np.zeros((NTEMP, self.Ntot))
# print 'xxx Load templates'
for i, key in enumerate(templates.keys()):
NTEMP += 1
temp = templates[key] # .zscale(z, 1.)
if hasattr(temp, "flux_flam"):
# eazy-py Template object
spectrum_1d = [temp.wave * (1 + z), temp.flux_flam(z=z) / (1 + z)]
else:
spectrum_1d = [temp.wave * (1 + z), temp.flux / (1 + z)]
if z > 4:
try:
import eazy.igm
igm = eazy.igm.Inoue14()
igmz = igm.full_IGM(z, spectrum_1d[0])
spectrum_1d[1] *= igmz
# print('IGM')
except:
# No IGM
pass
i0 = 0
for ib in range(self.N):
beam = self.beams[ib]
lam_beam = beam.beam.lam_beam
if (temp.wave.min() * (1 + z) > lam_beam.max()) | (
temp.wave.max() * (1 + z) < lam_beam.min()
):
tmodel = 0.0
else:
tmodel = beam.compute_model(
spectrum_1d=spectrum_1d, in_place=False, is_cgs=True
) # /beam.beam.total_flux
A_temp[i, i0 : i0 + self.Nflat[ib]] = tmodel # .flatten()
i0 += self.Nflat[ib]
if NTEMP > 0:
A = np.vstack((A, A_temp))
ok_temp = np.sum(A, axis=1) > 0
out_coeffs = np.zeros(A.shape[0])
# LSTSQ coefficients
# print 'xxx Fitter'
fit_functions = {"lstsq": np.linalg.lstsq, "nnls": scipy.optimize.nnls}
if fitter in fit_functions:
# 'lstsq':
Ax = A[:, self.fit_mask][ok_temp, :].T
# Weight by ivar
Ax *= np.sqrt(self.ivarf[self.fit_mask][:, np.newaxis])
# print 'xxx lstsq'
# out = numpy.linalg.lstsq(Ax,y)
if fitter == "lstsq":
y = self.scif[self.fit_mask]
# Weight by ivar
y *= np.sqrt(self.ivarf[self.fit_mask])
try:
out = np.linalg.lstsq(Ax, y, rcond=utils.LSTSQ_RCOND)
except:
print(A.min(), Ax.min(), self.fit_mask.sum(), y.min())
raise ValueError
lstsq_coeff, residuals, rank, s = out
coeffs = lstsq_coeff
if fitter == "nnls":
if fit_background:
off = 0.04
y = self.scif[self.fit_mask] + off
y *= np.sqrt(self.ivarf[self.fit_mask])
coeffs, rnorm = scipy.optimize.nnls(Ax, y + off)
coeffs[: self.N] -= 0.04
else:
y = self.scif[self.fit_mask]
y *= np.sqrt(self.ivarf[self.fit_mask])
coeffs, rnorm = scipy.optimize.nnls(Ax, y)
# if fitter == 'bounded':
# if fit_background:
# off = 0.04
# y = self.scif[self.fit_mask]+off
# y *= self.ivarf[self.fit_mask]
#
# coeffs, rnorm = scipy.optimize.nnls(Ax, y+off)
# coeffs[:self.N] -= 0.04
# else:
# y = self.scif[self.fit_mask]
# y *= np.sqrt(self.ivarf[self.fit_mask])
#
# coeffs, rnorm = scipy.optimize.nnls(Ax, y)
#
# out = scipy.optimize.minimize(self.eval_trace_shift, shifts, bounds=bounds, args=args, method='Powell', tol=tol)
elif HAS_SKLEARN:
Ax = A[:, self.fit_mask][ok_temp, :].T
y = self.scif[self.fit_mask]
# Wieght by ivar
Ax *= np.sqrt(self.ivarf[self.fit_mask][:, np.newaxis])
y *= np.sqrt(self.ivarf[self.fit_mask])
clf = sklearn.linear_model.LinearRegression()
status = clf.fit(Ax, y)
coeffs = clf.coef_
out_coeffs[ok_temp] = coeffs
modelf = np.dot(out_coeffs, A)
chi2 = np.sum(
(self.weightf * (self.scif - modelf) ** 2 * self.ivarf)[self.fit_mask]
)
if fit_background:
poly_coeffs = out_coeffs[self.N : self.N + self.n_poly]
else:
poly_coeffs = out_coeffs[: self.n_poly]
self.y_poly = np.dot(poly_coeffs, self.x_poly)
# x_poly = self.x_poly[1,:]+1 = self.beams[0].beam.lam/1.e4
return A, out_coeffs, chi2, modelf
[docs] def parse_fit_outputs(self, z, templates, coeffs_full, A):
"""
Parse output from `fit_at_z`.
Parameters
----------
z : float
Redshift at which to evaluate the fits.
templates : list of `~grizli.utils.SpectrumTemplate` objects
Generated with, e.g., `~grizli.utils.load_templates`.
coeffs_full : `~np.ndarray`
Template fit coefficients
A : `~np.ndarray`
Matrix generated for fits and used for computing model 2D spectra:
>>> model_flat = np.dot(coeffs_full, A)
>>> # mb = MultiBeam(...)
>>> all_models = mb.reshape_flat(model_flat)
>>> m0 = all_models[0] # model for mb.beams[0]
Returns
-------
line_flux : dict
Line fluxes and uncertainties, in cgs units (erg/s/cm2)
covar : `~np.ndarray`
Covariance matrix for the fit coefficients
cont1d, line1d, model1d : `~grizli.utils.SpectrumTemplate`
Best-fit continuum, line, and full (continuum + line) templates
model_continuum : `~np.ndarray`
Flat array of the best fit 2D continuum
"""
from collections import OrderedDict
# Covariance matrix for line flux uncertainties
Ax = A[:, self.fit_mask]
ok_temp = (np.sum(Ax, axis=1) > 0) & (coeffs_full != 0)
Ax = Ax[ok_temp, :].T * 1 # A[:, self.fit_mask][ok_temp,:].T
Ax *= np.sqrt(self.ivarf[self.fit_mask][:, np.newaxis])
try:
# covar = np.matrix(np.dot(Ax.T, Ax)).I
covar = utils.safe_invert(np.dot(Ax.T, Ax))
covard = np.sqrt(covar.diagonal())
except:
N = ok_temp.sum()
covar = np.zeros((N, N))
covard = np.zeros(N) # -1.
covar_full = utils.fill_masked_covar(covar, ok_temp)
# Random draws from covariance matrix
# draws = np.random.multivariate_normal(coeffs_full[ok_temp], covar, size=500)
line_flux_err = coeffs_full * 0.0
line_flux_err[ok_temp] = covard
# Continuum fit
mask = np.isfinite(coeffs_full)
for i, key in enumerate(templates.keys()):
if key.startswith("line"):
mask[self.N * self.fit_bg + self.n_poly + i] = False
model_continuum = np.dot(coeffs_full * mask, A)
self.model_continuum = self.reshape_flat(model_continuum)
# model_continuum.reshape(self.beam.sh_beam)
# 1D spectrum
# Polynomial component
xspec, yspec = self.eval_poly_spec(coeffs_full)
model1d = utils.SpectrumTemplate(xspec * 1.0e4, yspec)
cont1d = model1d * 1
i0 = self.fit_bg * self.N + self.n_poly
line_flux = OrderedDict()
fscl = 1.0 # self.beams[0].beam.total_flux/1.e-17
line1d = OrderedDict()
for i, key in enumerate(templates.keys()):
temp_i = templates[key].zscale(z, coeffs_full[i0 + i])
model1d += temp_i
if not key.startswith("line"):
cont1d += temp_i
else:
line1d[key.split()[1]] = temp_i
line_flux[key.split()[1]] = np.array(
[coeffs_full[i0 + i] * fscl, line_flux_err[i0 + i] * fscl]
)
return line_flux, covar_full, cont1d, line1d, model1d, model_continuum
[docs] def fit_stars(
self,
poly_order=1,
fitter="nnls",
fit_background=True,
verbose=True,
make_figure=True,
fwhm=0,
**kwargs
):
"""
Fit stellar templates to the data.
Parameters
----------
poly_order : int
Polynomial order to fit.
fitter : str
Fitting method. Options are 'nnls' and 'lstsq'.
fit_background : bool
Fit a constant background.
verbose : bool
Print information to the terminal.
make_figure : bool
Generate a figure of the fit.
zoom : bool
Zoom in on the best-fit redshift. Not used.
delta_chi2_threshold : float
Threshold for computing the redshift grid. Not used.
zr : float
Redshift to fit. Not used.
dz : float
Redshift step size. Not used.
fwhm : float
FWHM of the line spread function.
prior : None
Not used.
templates : dict
Dictionary of template spectra to fit.
figsize : list
Figure size for the output plot. Not used.
fsps_templates : bool
Not used.
Returns
-------
fit_data : dict
Fit data.
fig : `~matplotlib.figure.Figure`
Figure object.
"""
# Polynomial fit
out = self.fit_at_z(
z=0.0,
templates={},
fitter="lstsq",
poly_order=3,
fit_background=fit_background,
)
A, coeffs, chi2_poly, model_2d = out
# Star templates
templates = utils.load_templates(fwhm=fwhm, stars=True)
NTEMP = len(templates)
key = list(templates)[0]
temp_i = {key: templates[key]}
out = self.fit_at_z(
z=0.0,
templates=temp_i,
fitter=fitter,
poly_order=poly_order,
fit_background=fit_background,
)
A, coeffs, chi2, model_2d = out
chi2 = np.zeros(NTEMP)
coeffs = np.zeros((NTEMP, coeffs.shape[0]))
chi2min = 1e30
iz = 0
best = key
for i, key in enumerate(list(templates)):
temp_i = {key: templates[key]}
out = self.fit_at_z(
z=0.0,
templates=temp_i,
fitter=fitter,
poly_order=poly_order,
fit_background=fit_background,
)
A, coeffs[i, :], chi2[i], model_2d = out
if chi2[i] < chi2min:
iz = i
chi2min = chi2[i]
best = key
if verbose:
print(
utils.NO_NEWLINE + " {0} {1:9.1f} ({2})".format(key, chi2[i], best)
)
# Best-fit
temp_i = {best: templates[best]}
out = self.fit_at_z(
z=0.0,
templates=temp_i,
fitter=fitter,
poly_order=poly_order,
fit_background=fit_background,
)
A, coeffs_full, chi2_best, model_full = out
# Continuum fit
mask = np.isfinite(coeffs_full)
for i, key in enumerate(templates.keys()):
if key.startswith("line"):
mask[self.N * self.fit_bg + self.n_poly + i] = False
model_continuum = np.dot(coeffs_full * mask, A)
self.model_continuum = self.reshape_flat(model_continuum)
xspec, yspec = self.eval_poly_spec(coeffs_full)
model1d = utils.SpectrumTemplate(xspec * 1.0e4, yspec)
cont1d = model1d * 1
i0 = self.fit_bg * self.N + self.n_poly
line_flux = OrderedDict()
fscl = 1.0 # self.beams[0].beam.total_flux/1.e-17
temp_i = templates[best].zscale(0, coeffs_full[i0])
model1d += temp_i
cont1d += temp_i
fit_data = OrderedDict()
fit_data["poly_order"] = poly_order
fit_data["fwhm"] = 0
fit_data["zbest"] = np.argmin(chi2)
fit_data["chibest"] = chi2_best
fit_data["chi_poly"] = chi2_poly
fit_data["zgrid"] = np.arange(NTEMP)
fit_data["prior"] = 1
fit_data["A"] = A
fit_data["coeffs"] = coeffs
fit_data["chi2"] = chi2
fit_data["DoF"] = self.DoF
fit_data["model_full"] = model_full
fit_data["coeffs_full"] = coeffs_full
fit_data["line_flux"] = {}
# fit_data['templates_full'] = templates
fit_data["model_cont"] = model_continuum
fit_data["model1d"] = model1d
fit_data["cont1d"] = cont1d
# return fit_data
fig = None
if make_figure:
fig = self.show_redshift_fit(fit_data)
# fig.savefig('fit.pdf')
return fit_data, fig
[docs] def fit_redshift(
self,
prior=None,
poly_order=1,
fwhm=1200,
make_figure=True,
zr=None,
dz=None,
verbose=True,
fit_background=True,
fitter="nnls",
delta_chi2_threshold=0.004,
zoom=True,
line_complexes=True,
templates={},
figsize=[8, 5],
fsps_templates=False,
):
"""
Fit redshifts for the object.
Parameters
----------
prior: list or array-like or None
A list or array-like object containing two elements:
- prior[0]: The x-coordinates for interpolation.
- prior[1]: The y-coordinates for interpolation.
Used to compute interpolated values over
the `zgrid` using `numpy.interp` when not `None`.
poly_order : int
Polynomial order to fit.
fwhm : float
FWHM of the line spread function.
make_figure : bool
Generate a figure of the fit.
zr : list or None
Redshift range to fit.
dz : list or None
Redshift step size.
verbose : bool
Print information to the terminal.
fit_background : bool
Fit a constant background.
fitter : str
Fitting method. Options are 'nnls' and 'lstsq'.
delta_chi2_threshold : float
Threshold for computing the redshift grid. Not used.
zoom : bool
Zoom in on the best-fit redshift.
line_complexes : bool
Include line complexes in the fit.
templates : dict
Dictionary of template spectra to fit.
figsize : list
Figure size for the output plot.
fsps_templates : bool
Use FSPS templates.
Returns
-------
fit_data : dict
Fit data.
fig : `~matplotlib.figure.Figure`
Figure object.
"""
from numpy.polynomial import Polynomial
if zr is None:
zr = [0.65, 1.6]
if dz is None:
dz = [0.005, 0.0004]
if zr in [0]:
stars = True
zr = [0, 0.01]
fitter = "nnls"
else:
stars = False
zgrid = utils.log_zgrid(zr, dz=dz[0])
NZ = len(zgrid)
# Polynomial fit
out = self.fit_at_z(
z=0.0,
templates={},
fitter="lstsq",
poly_order=3,
fit_background=fit_background,
)
A, coeffs, chi2_poly, model_2d = out
# Set up for template fit
if templates == {}:
templates = utils.load_templates(
fwhm=fwhm,
stars=stars,
line_complexes=line_complexes,
fsps_templates=fsps_templates,
)
else:
if verbose:
print("User templates! N={0} \n".format(len(templates)))
NTEMP = len(templates)
out = self.fit_at_z(
z=0.0,
templates=templates,
fitter=fitter,
poly_order=poly_order,
fit_background=fit_background,
)
A, coeffs, chi2, model_2d = out
chi2 = np.zeros(NZ)
coeffs = np.zeros((NZ, coeffs.shape[0]))
chi2min = 1e30
iz = 0
for i in range(NZ):
out = self.fit_at_z(
z=zgrid[i],
templates=templates,
fitter=fitter,
poly_order=poly_order,
fit_background=fit_background,
)
A, coeffs[i, :], chi2[i], model_2d = out
if chi2[i] < chi2min:
iz = i
chi2min = chi2[i]
if verbose:
print(
utils.NO_NEWLINE
+ " {0:.4f} {1:9.1f} ({2:.4f})".format(
zgrid[i], chi2[i], zgrid[iz]
)
)
print("First iteration: z_best={0:.4f}\n".format(zgrid[iz]))
# peaks
# chi2nu = (chi2.min()-chi2)/self.DoF
# indexes = utils.find_peaks((chi2nu+delta_chi2_threshold)*(chi2nu > - delta_chi2_threshold), threshold=0.3, min_dist=21)
chi2_rev = (chi2_poly - chi2) / self.DoF
if chi2_poly < (chi2.min() + 9):
chi2_rev = (chi2.min() + 16 - chi2) / self.DoF
chi2_rev[chi2_rev < 0] = 0
indexes = utils.find_peaks(chi2_rev, threshold=0.4, min_dist=9)
num_peaks = len(indexes)
if False:
plt.plot(zgrid, (chi2 - chi2.min()) / self.DoF)
plt.scatter(
zgrid[indexes], (chi2 - chi2.min())[indexes] / self.DoF, color="r"
)
# delta_chi2 = (chi2.max()-chi2.min())/self.DoF
# if delta_chi2 > delta_chi2_threshold:
if (num_peaks > 0) & (not stars) & zoom:
zgrid_zoom = []
for ix in indexes:
if (ix > 0) & (ix < len(chi2) - 1):
p = Polynomial.fit(
zgrid[ix - 1 : ix + 2], chi2[ix - 1 : ix + 2], deg=2
)
zi = p.deriv().roots()[0]
chi_i = p(zi)
zgrid_zoom.extend(
np.arange(zi - 2 * dz[0], zi + 2 * dz[0] + dz[1] / 10.0, dz[1])
)
# zgrid_zoom = utils.zoom_zgrid(zgrid, chi2/self.DoF,
# threshold=delta_chi2_threshold,
# factor=dz[0]/dz[1])
NZOOM = len(zgrid_zoom)
chi2_zoom = np.zeros(NZOOM)
coeffs_zoom = np.zeros((NZOOM, coeffs.shape[1]))
iz = 0
chi2min = 1.0e30
for i in range(NZOOM):
out = self.fit_at_z(
z=zgrid_zoom[i],
templates=templates,
fitter=fitter,
poly_order=poly_order,
fit_background=fit_background,
)
A, coeffs_zoom[i, :], chi2_zoom[i], model_2d = out
if chi2_zoom[i] < chi2min:
chi2min = chi2_zoom[i]
iz = i
if verbose:
print(
utils.NO_NEWLINE
+ "- {0:.4f} {1:9.1f} ({2:.4f}) {3:d}/{4:d}".format(
zgrid_zoom[i], chi2_zoom[i], zgrid_zoom[iz], i + 1, NZOOM
)
)
zgrid = np.append(zgrid, zgrid_zoom)
chi2 = np.append(chi2, chi2_zoom)
coeffs = np.append(coeffs, coeffs_zoom, axis=0)
so = np.argsort(zgrid)
zgrid = zgrid[so]
chi2 = chi2[so]
coeffs = coeffs[so, :]
if prior is not None:
# print('\n\nPrior!\n\n', chi2.min(), prior[1].min())
interp_prior = np.interp(zgrid, prior[0], prior[1])
chi2 += interp_prior
else:
interp_prior = None
print(" Zoom iteration: z_best={0:.4f}\n".format(zgrid[np.argmin(chi2)]))
# Best redshift
if not stars:
templates = utils.load_templates(
line_complexes=False, fwhm=fwhm, fsps_templates=fsps_templates
)
zbest = zgrid[np.argmin(chi2)]
ix = np.argmin(chi2)
chibest = chi2.min()
# Fit parabola
if (ix > 0) & (ix < len(chi2) - 1):
p = Polynomial.fit(zgrid[ix - 1 : ix + 2], chi2[ix - 1 : ix + 2], deg=2)
zbest = p.deriv().roots()[0]
chibest = p(zbest)
out = self.fit_at_z(
z=zbest,
templates=templates,
fitter=fitter,
poly_order=poly_order,
fit_background=fit_background,
)
A, coeffs_full, chi2_best, model_full = out
# Parse results
out2 = self.parse_fit_outputs(zbest, templates, coeffs_full, A)
line_flux, covar, cont1d, line1d, model1d, model_continuum = out2
# Output dictionary with fit parameters
fit_data = OrderedDict()
fit_data["poly_order"] = poly_order
fit_data["fwhm"] = fwhm
fit_data["zbest"] = zbest
fit_data["chibest"] = chibest
fit_data["chi_poly"] = chi2_poly
fit_data["zgrid"] = zgrid
fit_data["prior"] = interp_prior
fit_data["A"] = A
fit_data["coeffs"] = coeffs
fit_data["chi2"] = chi2
fit_data["DoF"] = self.DoF
fit_data["model_full"] = model_full
fit_data["coeffs_full"] = coeffs_full
fit_data["covar"] = covar
fit_data["line_flux"] = line_flux
# fit_data['templates_full'] = templates
fit_data["model_cont"] = model_continuum
fit_data["model1d"] = model1d
fit_data["cont1d"] = cont1d
fit_data["line1d"] = line1d
# return fit_data
fig = None
if make_figure:
fig = self.show_redshift_fit(fit_data, figsize=figsize)
# fig.savefig('fit.pdf')
return fit_data, fig
[docs] def run_individual_fits(self, z=0, templates={}):
"""
Run template fits on each *exposure* individually to evaluate
variance in line and continuum fits.
Parameters
----------
z : float
Redshift at which to evaluate the fit
templates : list of `~grizli.utils.SpectrumTemplate` objects
Generated with, e.g., `load_templates`.
Returns
-------
line_flux, line_err : dict
Dictionaries with the measured line fluxes and uncertainties for
each exposure fit.
coeffs_list : `~np.ndarray` [Nbeam x Ntemplate]
Raw fit coefficients
chi2_list, DoF_list : `~np.ndarray` [Nbeam]
Chi-squared and effective degrees of freedom for each separate fit
"""
# Fit on the full set of beams
out = self.fit_at_z(
z=z,
templates=templates,
fitter="nnls",
poly_order=self.poly_order,
fit_background=self.fit_bg,
)
A, coeffs_full, chi2_best, model_full = out
out2 = self.parse_fit_outputs(z, templates, coeffs_full, A)
line, covar, cont1d, line1d, model1d, model_continuum = out2
NB, NTEMP = len(self.beams), len(templates)
# Outputs
coeffs_list = np.zeros((NB, NTEMP))
chi2_list = np.zeros(NB)
DoF_list = np.zeros(NB)
line_flux = OrderedDict()
line_err = OrderedDict()
line_keys = list(line.keys())
for k in line_keys:
line_flux[k] = np.zeros(NB)
line_err[k] = np.zeros(NB)
# Generate separate MultiBeam objects for each individual beam
for i, b in enumerate(self.beams):
b_i = MultiBeam([b], fcontam=self.fcontam, group_name=self.group_name)
out_i = b_i.fit_at_z(
z=z,
templates=templates,
fitter="nnls",
poly_order=self.poly_order,
fit_background=self.fit_bg,
)
A_i, coeffs_i, chi2_i, model_full_i = out_i
# Parse fit information from individual fits
out2 = b_i.parse_fit_outputs(z, templates, coeffs_i, A_i)
line_i, covar_i, cont1d_i, line1d_i, model1d_i, model_continuum_i = out2
for k in line_keys:
line_flux[k][i] = line_i[k][0]
line_err[k][i] = line_i[k][1]
coeffs_list[i, :] = coeffs_i[-NTEMP:]
chi2_list[i] = chi2_i
DoF_list[i] = b_i.DoF
return line_flux, line_err, coeffs_list, chi2_list, DoF_list
[docs] def show_redshift_fit(self, fit_data, plot_flambda=True, figsize=[8, 5]):
"""
Show redshift fit.
Parameters
----------
fit_data : dict
Fit data generated by `fit_redshift`.
plot_flambda : bool
Plot in f-lambda units.
figsize : list
Figure size for the output plot.
Returns
-------
fig : `~matplotlib.figure.Figure`
Figure object.
"""
import matplotlib.gridspec
gs = matplotlib.gridspec.GridSpec(2, 1, height_ratios=[0.6, 1])
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(gs[0])
c2min = fit_data["chi2"].min()
scale_pz = True
if scale_pz:
scale_nu = c2min / self.DoF
scl_label = "_s"
else:
scale_nu = 1.0
scl_label = ""
# axz.plot(z, (chi2-chi2.min())/scale_nu, color='k')
# ax.plot(fit_data['zgrid'], fit_data['chi2']/self.DoF)
ax.plot(fit_data["zgrid"], (fit_data["chi2"] - c2min) / scale_nu)
ax.set_xlabel("z")
ax.set_ylabel(r"$\chi^2_\nu$, $\nu$={0:d}".format(self.DoF))
ax.set_ylim(-4, 27)
ax.set_ylabel(
r"$\Delta\chi^2{2}$ ({0:.0f}/$\nu$={1:d})".format(
c2min, self.DoF, scl_label
)
)
ax.set_yticks([1, 4, 9, 16, 25])
# for delta in [1,4,9]:
# ax.plot(fit_data['zgrid'],
# fit_data['zgrid']*0.+(c2min+delta)/self.DoF,
# color='{0:.2f}'.format(1-delta*1./10))
ax.plot(
fit_data["zgrid"],
(fit_data["chi2"] * 0 + fit_data["chi_poly"] - c2min) / scale_nu,
color="b",
linestyle="--",
alpha=0.8,
)
ax.set_xlim(fit_data["zgrid"].min(), fit_data["zgrid"].max())
ax.grid()
ax.set_title(
r"ID = {0:d}, $z_\mathrm{{grism}}$={1:.4f}".format(
self.beams[0].id, fit_data["zbest"]
)
)
ax = fig.add_subplot(gs[1])
ymax = 0
ymin = 1e10
continuum_fit = self.reshape_flat(fit_data["model_cont"])
line_fit = self.reshape_flat(fit_data["model_full"])
grisms = self.Ngrism.keys()
wfull = {}
ffull = {}
efull = {}
for grism in grisms:
wfull[grism] = []
ffull[grism] = []
efull[grism] = []
for ib in range(self.N):
beam = self.beams[ib]
clean = beam.grism["SCI"] - beam.contam
if self.fit_bg:
bg_i = fit_data["coeffs_full"][ib]
clean -= bg_i # background
else:
bg_i = 0.0
# ivar = 1./(1./beam.ivar + self.fcontam*beam.contam)
# ivar[~np.isfinite(ivar)] = 0
# New weight scheme
ivar = beam.ivar
weight = np.exp(-(self.fcontam * np.abs(beam.contam) * np.sqrt(ivar)))
wave, flux, err = beam.beam.optimal_extract(clean, ivar=ivar, weight=weight)
mwave, mflux, merr = beam.beam.optimal_extract(
line_fit[ib] - bg_i, ivar=ivar, weight=weight
)
flat = beam.flat_flam.reshape(beam.beam.sh_beam)
wave, fflux, ferr = beam.beam.optimal_extract(
flat, ivar=ivar, weight=weight
)
if plot_flambda:
ok = beam.beam.sensitivity > 0.1 * beam.beam.sensitivity.max()
wave = wave[ok]
fscl = 1.0 / 1.0e-19 # beam.beam.total_flux/1.e-17
flux = (flux * fscl / fflux)[ok] * beam.beam.scale
err = (err * fscl / fflux)[ok]
mflux = (mflux * fscl / fflux)[ok] * beam.beam.scale
ylabel = r"$f_\lambda\,/\,10^{-19}\,\mathrm{cgs}$"
else:
ylabel = "flux (e-/s)"
scl_region = np.isfinite(mflux)
if scl_region.sum() == 0:
continue
# try:
# okerr = np.isfinite(err) #& (np.abs(flux/err) > 0.2) & (err != 0)
# med_err = np.median(err[okerr])
#
# ymax = np.maximum(ymax,
# (mflux[scl_region][2:-2] + med_err).max())
# ymin = np.minimum(ymin,
# (mflux[scl_region][2:-2] - med_err).min())
# except:
# continue
# okerr = (err != 0) & (np.abs(flux/err) > 0.2)
okerr = np.isfinite(err)
ax.errorbar(
wave[okerr] / 1.0e4,
flux[okerr],
err[okerr],
alpha=0.15 + 0.2 * (self.N <= 2),
linestyle="None",
marker=".",
color="{0:.2f}".format(ib * 0.5 / self.N),
zorder=1,
)
ax.plot(wave[okerr] / 1.0e4, mflux[okerr], color="r", alpha=0.5, zorder=3)
if beam.grism.instrument == "NIRISS":
grism = beam.grism.pupil
else:
grism = beam.grism.filter
# for grism in grisms:
wfull[grism] = np.append(wfull[grism], wave[okerr])
ffull[grism] = np.append(ffull[grism], flux[okerr])
efull[grism] = np.append(efull[grism], err[okerr])
# Scatter direct image flux
if beam.direct.ref_photplam is None:
ax.scatter(
beam.direct.photplam / 1.0e4,
beam.beam.total_flux / 1.0e-19,
marker="s",
edgecolor="k",
color=GRISM_COLORS[grism],
alpha=0.2,
zorder=100,
s=100,
)
else:
ax.scatter(
beam.direct.ref_photplam / 1.0e4,
beam.beam.total_flux / 1.0e-19,
marker="s",
edgecolor="k",
color=GRISM_COLORS[grism],
alpha=0.2,
zorder=100,
s=100,
)
for grism in grisms:
if self.Ngrism[grism] > 1:
# binned
okb = (
np.isfinite(wfull[grism])
& np.isfinite(ffull[grism])
& np.isfinite(efull[grism])
)
so = np.argsort(wfull[grism][okb])
var = efull[grism] ** 2
N = int(np.ceil(self.Ngrism[grism] / 2) * 2) * 2
kernel = np.ones(N, dtype=float) / N
wht = 1 / var[okb][so]
fbin = nd.convolve(ffull[grism][okb][so] * wht, kernel)[N // 2 :: N]
wbin = nd.convolve(wfull[grism][okb][so] * wht, kernel)[N // 2 :: N]
# vbin = nd.convolve(var[okb][so], kernel**2)[N//2::N]
wht_bin = nd.convolve(wht, kernel)[N // 2 :: N]
vbin = nd.convolve(wht, kernel ** 2)[N // 2 :: N] / wht_bin ** 2
fbin /= wht_bin
wbin /= wht_bin
# vbin = 1./wht_bin
ax.errorbar(
wbin / 1.0e4,
fbin,
np.sqrt(vbin),
alpha=0.8,
linestyle="None",
marker=".",
color=GRISM_COLORS[grism],
zorder=2,
)
med_err = np.median(np.sqrt(vbin))
ymin = np.minimum(ymin, (fbin - 2 * med_err).min())
ymax = np.maximum(ymax, (fbin + 2 * med_err).max())
ymin = np.maximum(0, ymin)
ax.set_ylim(ymin - 0.2 * np.abs(ymax), 1.3 * ymax)
xmin, xmax = 1.0e5, 0
for g in GRISM_LIMITS:
if g in grisms:
xmin = np.minimum(xmin, GRISM_LIMITS[g][0])
xmax = np.maximum(xmax, GRISM_LIMITS[g][1])
# print g, xmin, xmax
ax.set_xlim(xmin, xmax)
ax.semilogx(subsx=[xmax])
# axc.set_xticklabels([])
# axc.set_xlabel(r'$\lambda$')
# axc.set_ylabel(r'$f_\lambda \times 10^{-19}$')
from matplotlib.ticker import MultipleLocator
ax.xaxis.set_major_locator(MultipleLocator(0.1))
labels = np.arange(np.ceil(xmin * 10), np.ceil(xmax * 10)) / 10.0
ax.set_xticks(labels)
ax.set_xticklabels(labels)
ax.grid()
# Label
ax.text(
0.03,
1.03,
("{0}".format(self.Ngrism))
.replace("'", "")
.replace("{", "")
.replace("}", ""),
ha="left",
va="bottom",
transform=ax.transAxes,
fontsize=10,
)
# ax.plot(wave/1.e4, wave/1.e4*0., linestyle='--', color='k')
ax.hlines(0, xmin, xmax, linestyle="--", color="k")
ax.set_xlabel(r"$\lambda$")
ax.set_ylabel(ylabel)
gs.tight_layout(fig, pad=0.1)
return fig
[docs] def drizzle_segmentation(
self, wcsobj=None, kernel="square", pixfrac=1, verbose=False
):
"""
Drizzle segmentation image from individual `MultiBeam.beams`.
Parameters
----------
wcsobj: `~astropy.wcs.WCS` or `~astropy.io.fits.Header`
Output WCS.
kernel: e.g., 'square', 'point', 'gaussian'
Drizzle kernel, see `~drizzlepac.adrizzle.drizzle`.
pixfrac: float
Drizzle 'pixfrac', see `~drizzlepac.adrizzle.drizzle`.
verbose: bool
Print status messages.
Returns
-------
drizzled_segm: `~numpy.ndarray`, type `~numpy.int64`.
Drizzled segmentation image, with image dimensions and
WCS defined in `wcsobj`.
"""
import numpy as np
import astropy.wcs as pywcs
import astropy.io.fits as pyfits
try:
from . import utils
except:
from grizli import multifit, utils
all_ids = [np.unique(beam.beam.seg) for beam in self.beams]
all_ids = np.unique(np.hstack(all_ids))[1:]
if isinstance(wcsobj, pyfits.Header):
wcs = pywcs.WCS(wcsobj)
wcs.pscale = utils.get_wcs_pscale(wcs)
else:
wcs = wcsobj
if not hasattr(wcs, "pscale"):
wcs.pscale = utils.get_wcs_pscale(wcs)
if verbose:
print("Drizzle ID={0:.0f} (primary)".format(self.id))
drizzled_segm = self.drizzle_segmentation_id(
id=self.id, wcsobj=wcsobj, kernel=kernel, pixfrac=pixfrac, verbose=verbose
)
for id in all_ids:
if int(id) == self.id:
continue
if verbose:
print("Drizzle ID={0:.0f}".format(id))
dseg_i = self.drizzle_segmentation_id(
id=id, wcsobj=wcsobj, kernel=kernel, pixfrac=pixfrac, verbose=False
)
new_seg = drizzled_segm == 0
drizzled_segm[new_seg] = dseg_i[new_seg]
return drizzled_segm
[docs] def drizzle_segmentation_id(
self, id=None, wcsobj=None, kernel="square", pixfrac=1, verbose=True
):
"""
Drizzle segmentation image for a single ID
Parameters
----------
id : int
ID to drizzle.
wcsobj : `~astropy.wcs.WCS` or `~astropy.io.fits.Header`
Output WCS.
kernel : e.g., 'square', 'point', 'gaussian'
Drizzle kernel, see `~drizzlepac.adrizzle.drizzle`.
pixfrac : float
Drizzle 'pixfrac', see `~drizzlepac.adrizzle.drizzle`.
verbose : bool
Print status messages.
Returns
-------
drizzled_segm : `~numpy.ndarray`, type `~numpy.int64`.
Drizzled segmentation image, with image dimensions and
WCS defined in `wcsobj`.
"""
import numpy as np
import astropy.wcs as pywcs
import astropy.io.fits as pyfits
try:
from . import utils
except:
from grizli import multifit, utils
# Can be either a header or WCS object
if isinstance(wcsobj, pyfits.Header):
wcs = pywcs.WCS(wcsobj)
wcs.pscale = utils.get_wcs_pscale(wcs)
else:
wcs = wcsobj
if not hasattr(wcs, "pscale"):
wcs.pscale = utils.get_wcs_pscale(wcs)
if id is None:
id = self.id
sci_list = [(beam.beam.seg == id) * 1.0 for beam in self.beams]
wht_list = [np.isfinite(beam.beam.seg) * 1.0 for beam in self.beams]
wcs_list = [beam.direct.wcs for beam in self.beams]
out = utils.drizzle_array_groups(
sci_list,
wht_list,
wcs_list,
outputwcs=wcs,
scale=0.1,
kernel=kernel,
pixfrac=pixfrac,
verbose=verbose,
)
drizzled_segm = (out[0] > 0) * id
return drizzled_segm
[docs] def drizzle_fit_lines(
self,
fit,
pline,
force_line=["Ha+NII", "Ha", "OIII", "Hb", "OII"],
save_fits=True,
mask_lines=True,
mask_sn_limit=3,
mask_4959=True,
verbose=True,
include_segmentation=True,
get_ir_psfs=True,
min_line_sn=4,
):
"""
Drizzle line maps from the individual beams.
Parameters
----------
fit : dict
Fit data generated by `fit_redshift`.
pline : list
List of line names to drizzle.
force_line : list
List of line names to always drizzle.
save_fits : bool
Save the drizzled line maps to FITS files.
mask_lines : bool
Mask pixels around other lines.
mask_sn_limit : float
Mask pixels around lines with S/N below this limit.
mask_4959 : bool
Mask OIII-4959 if detected.
verbose : bool
Print status messages.
include_segmentation : bool
Drizzle the segmentation image.
get_ir_psfs : bool
Get the IR PSFs from the `MultiBeam.beams`.
min_line_sn : float
Minimum S/N to drizzle a line.
Returns
-------
hdu_full : list
List of `~astropy.io.fits.HDUList` objects with the drizzled line maps.
"""
line_wavelengths, line_ratios = utils.get_line_wavelengths()
hdu_full = []
saved_lines = []
if ("cfit" in fit) & mask_4959:
if "line OIII" in fit["templates"]:
t_o3 = utils.load_templates(
fwhm=fit["templates"]["line OIII"].fwhm,
line_complexes=False,
stars=False,
full_line_list=["OIII-4959"],
continuum_list=[],
fsps_templates=False,
)
if "zbest" in fit:
z_driz = fit["zbest"]
else:
z_driz = fit["z"]
if "line_flux" in fit:
line_flux_dict = fit["line_flux"]
else:
line_flux_dict = OrderedDict()
for key in fit["cfit"]:
if key.startswith("line"):
line_flux_dict[key.replace("line ", "")] = fit["cfit"][key]
# Compute continuum model
if "cfit" in fit:
if "bg {0:03d}".format(self.N - 1) in fit["cfit"]:
for ib, beam in enumerate(self.beams):
key = "bg {0:03d}".format(ib)
self.beams[ib].background = fit["cfit"][key][0]
cont = fit["cont1d"]
for beam in self.beams:
beam.compute_model(spectrum_1d=[cont.wave, cont.flux], is_cgs=True)
if hasattr(self, "pscale"):
if self.pscale is not None:
scale = self.compute_scale_array(self.pscale, beam.wavef)
beam.beam.pscale_array = scale.reshape(beam.sh)
else:
beam.beam.pscale_array = 1.0
else:
beam.beam.pscale_array = 1.0
for line in line_flux_dict:
line_flux, line_err = line_flux_dict[line]
if line_err == 0:
continue
# Skip if min_line_sn = inf
if not np.isfinite(min_line_sn):
continue
if (line_flux / line_err > min_line_sn) | (line in force_line):
if verbose:
print(
"Drizzle line -> {0:4s} ({1:.2f} {2:.2f})".format(
line, line_flux / 1.0e-17, line_err / 1.0e-17
)
)
line_wave_obs = line_wavelengths[line][0] * (1 + z_driz)
if mask_lines:
for beam in self.beams:
beam.oivar = beam.ivar * 1
lam = beam.beam.lam_beam
if hasattr(beam.beam, "pscale_array"):
pscale_array = beam.beam.pscale_array
else:
pscale_array = 1.0
# another idea, compute a model for the line itself
# and mask relatively "contaminated" pixels from
# other lines
try:
lm = fit["line1d"][line]
sp = [lm.wave, lm.flux]
except:
key = "line " + line
lm = fit["templates"][key]
line_flux = fit["cfit"][key][0]
scl = line_flux / (1 + z_driz)
sp = [lm.wave * (1 + z_driz), lm.flux * scl]
# lm = fit['line1d'][line]
if (lm.wave.max() < lam.min()) | (lm.wave.min() > lam.max()):
continue
# sp = [lm.wave, lm.flux]
if line_flux > 0:
m = beam.compute_model(
spectrum_1d=sp, in_place=False, is_cgs=True
)
lmodel = m.reshape(beam.beam.sh_beam) * pscale_array
else:
lmodel = np.zeros(beam.beam.sh_beam)
# if lmodel.max() == 0:
# continue
if "cfit" in fit:
keys = fit["cfit"]
else:
keys = fit["line1d"]
beam.extra_lines = beam.contam * 0.0
for lkey in keys:
if not lkey.startswith("line"):
continue
key = lkey.replace("line ", "")
lf, le = line_flux_dict[key]
# Don't mask if the line missing or undetected
if lf <= 0: # | (lf < mask_sn_limit*le):
continue
if key != line:
try:
lm = fit["line1d"][lkey]
sp = [lm.wave, lm.flux]
except:
lm = fit["templates"][lkey]
scl = fit["cfit"][lkey][0] / (1 + z_driz)
sp = [lm.wave * (1 + z_driz), lm.flux * scl]
if (lm.wave.max() < lam.min()) | (
lm.wave.min() > lam.max()
):
continue
m = beam.compute_model(
spectrum_1d=sp, in_place=False, is_cgs=True
)
lcontam = m.reshape(beam.beam.sh_beam)
lcontam *= pscale_array
if lcontam.max() == 0:
# print beam.grism.parent_file, lkey
continue
beam.extra_lines += lcontam
# Only mask if line flux > 0
if line_flux > 0:
extra_msk = lcontam > mask_sn_limit * lmodel
extra_msk &= lcontam > 0
extra_msk &= lmodel > 0
beam.ivar[extra_msk] *= 0
# Subtract 4959
if (line == "OIII") & ("cfit" in fit) & mask_4959:
lm = t_o3["line OIII-4959"]
scl = fit["cfit"]["line OIII"][0] / (1 + z_driz)
scl *= 1.0 / (2.98 + 1)
sp = [lm.wave * (1 + z_driz), lm.flux * scl]
if (lm.wave.max() < lam.min()) | (
lm.wave.min() > lam.max()
):
continue
m = beam.compute_model(
spectrum_1d=sp, in_place=False, is_cgs=True
)
lcontam = m.reshape(beam.beam.sh_beam)
lcontam *= pscale_array
if lcontam.max() == 0:
continue
# print('Mask 4959!')
beam.extra_lines += lcontam
hdu = drizzle_to_wavelength(
self.beams,
ra=self.ra,
dec=self.dec,
wave=line_wave_obs,
fcontam=self.fcontam,
**pline,
)
if mask_lines:
for beam in self.beams:
beam.ivar = beam.oivar * 1
delattr(beam, "oivar")
hdu[0].header["REDSHIFT"] = (z_driz, "Redshift used")
# for e in [3,4,5,6]:
for e in [-4, -3, -2, -1]:
hdu[e].header["EXTVER"] = line
hdu[e].header["REDSHIFT"] = (z_driz, "Redshift used")
hdu[e].header["RESTWAVE"] = (
line_wavelengths[line][0],
"Line rest wavelength",
)
saved_lines.append(line)
if len(hdu_full) == 0:
hdu_full = hdu
hdu_full[0].header["NUMLINES"] = (1, "Number of lines in this file")
else:
hdu_full.extend(hdu[-4:])
hdu_full[0].header["NUMLINES"] += 1
# Make sure DSCI extension is filled. Can be empty for
# lines at the edge of the grism throughput
for f_i in range(hdu[0].header["NDFILT"]):
filt_i = hdu[0].header["DFILT{0:02d}".format(f_i + 1)]
if hdu["DWHT", filt_i].data.max() != 0:
hdu_full["DSCI", filt_i] = hdu["DSCI", filt_i]
hdu_full["DWHT", filt_i] = hdu["DWHT", filt_i]
li = hdu_full[0].header["NUMLINES"]
hdu_full[0].header["LINE{0:03d}".format(li)] = line
hdu_full[0].header["FLUX{0:03d}".format(li)] = (
line_flux,
"Line flux, 1e-17 erg/s/cm2",
)
hdu_full[0].header["ERR{0:03d}".format(li)] = (
line_err,
"Line flux err, 1e-17 erg/s/cm2",
)
if len(hdu_full) > 0:
hdu_full[0].header["HASLINES"] = (
" ".join(saved_lines),
"Lines in this file",
)
else:
hdu = drizzle_to_wavelength(
self.beams,
ra=self.ra,
dec=self.dec,
wave=np.median(self.beams[0].wave),
fcontam=self.fcontam,
**pline,
)
hdu_full = hdu[:-4]
hdu_full[0].header["REDSHIFT"] = (z_driz, "Redshift used")
hdu_full[0].header["NUMLINES"] = 0
hdu_full[0].header["HASLINES"] = " "
if include_segmentation:
line_wcs = pywcs.WCS(hdu_full[1].header)
segm = self.drizzle_segmentation(wcsobj=line_wcs)
seg_hdu = pyfits.ImageHDU(data=segm.astype(np.int32), name="SEG")
hdu_full.insert(1, seg_hdu)
if get_ir_psfs:
import grizli.galfit.psf
ir_beams = []
gr_filters = {
"G102": ["F105W"],
"G141": ["F105W", "F125W", "F140W", "F160W"],
}
show_filters = []
for gr in ["G102", "G141"]:
if gr in self.PA:
show_filters.extend(gr_filters[gr])
for pa in self.PA[gr]:
for i in self.PA[gr][pa]:
ir_beams.append(self.beams[i])
if len(ir_beams) > 0:
dp = grizli.galfit.psf.DrizzlePSF(
driz_hdu=hdu_full["DSCI"], beams=self.beams
)
for filt in np.unique(show_filters):
if verbose:
print("Get linemap PSF: {0}".format(filt))
psf = dp.get_psf(
ra=dp.driz_wcs.wcs.crval[0],
dec=dp.driz_wcs.wcs.crval[1],
filter=filt,
pixfrac=dp.driz_header["PIXFRAC"],
kernel=dp.driz_header["DRIZKRNL"],
wcs_slice=dp.driz_wcs,
get_extended=True,
verbose=False,
get_weight=False,
)
psf[1].header["EXTNAME"] = "DPSF"
psf[1].header["EXTVER"] = filt
hdu_full.append(psf[1])
if save_fits:
hdu_full.writeto(
"{0}_{1:05d}.line.fits".format(self.group_name, self.id),
overwrite=True,
output_verify="silentfix",
)
return hdu_full
[docs] def run_full_diagnostics(
self,
pzfit={},
pspec2={},
pline={},
force_line=["Ha+NII", "Ha", "OIII", "Hb", "OII"],
GroupFLT=None,
prior=None,
zoom=True,
verbose=True,
):
"""
Run full redshift fit and diagnostics on a `GroupFLT` object.
Parameters
----------
pzfit : dict
Parameters for the redshift fit.
pspec2 : dict
Parameters for the 2D spectrum drizzle.
pline : dict
Parameters for the line drizzle.
force_line : list
List of line names to always drizzle.
GroupFLT : `~grizli.multifit.GroupFLT`
Group object with the full set of objects to fit simultaneously.
prior : `~astropy.table.Table`
Prior catalog for the redshift fit.
zoom : bool
Zoom in on the object.
verbose : bool
Print status messages.
Returns
-------
fig : `~matplotlib.figure.Figure`
Figure object.
hdu2 : `~astropy.io.fits.HDUList`
Drizzled 2D spectrum.
hdu_full : `~astropy.io.fits.HDUList`
Drizzled line maps.
"""
import copy
# Defaults
pzfit_def, pspec2_def, pline_def = get_redshift_fit_defaults()
if pzfit == {}:
pzfit = pzfit_def
if pspec2 == {}:
pspec2 = pspec2_def
if pline == {}:
pline = pline_def
# Check that keywords allowed
for d, default in zip(
[pzfit, pspec2, pline], [pzfit_def, pspec2_def, pline_def]
):
for key in d:
if key not in default:
p = d.pop(key)
# Auto generate FWHM (in km/s) to use for line fits
if "fwhm" in pzfit:
fwhm = pzfit["fwhm"]
if pzfit["fwhm"] == 0:
if "G141" in self.Ngrism:
# WFC3/IR
fwhm = 1200
elif "G800L" in self.Ngrism:
# ACS/WFC
fwhm = 1400
elif "G280" in self.Ngrism:
# UVIS
fwhm = 1500
elif "GRISM" in self.Ngrism:
# WFIRST
fwhm = 350
elif "G150" in self.Ngrism:
# WFIRST
fwhm = 350
else:
fwhm = 700
# Auto generate delta-wavelength of 2D spectrum
if "dlam" in pspec2:
dlam = pspec2["dlam"]
if dlam == 0:
if "G141" in self.Ngrism:
dlam = 45
elif "G800L" in self.Ngrism:
dlam = 40
elif "G280" in self.Ngrism:
dlam = 18
elif "GRISM" in self.Ngrism:
dlam = 11
elif "G150" in self.Ngrism:
dlam = 11
else:
dlam = 25 # G102
# Redshift fit
zfit_in = copy.copy(pzfit)
zfit_in["fwhm"] = fwhm
zfit_in["prior"] = prior
zfit_in["zoom"] = zoom
zfit_in["verbose"] = verbose
if zfit_in["zr"] in [0]:
fit, fig = self.fit_stars(**zfit_in)
else:
fit, fig = self.fit_redshift(**zfit_in)
# Make sure model attributes are set to the continuum model
models = self.reshape_flat(fit["model_cont"])
for j in range(self.N):
self.beams[j].model = models[j] * 1
# 2D spectrum
spec_in = copy.copy(pspec2)
spec_in["fit"] = fit
spec_in["dlam"] = dlam
# fig2, hdu2 = self.redshift_fit_twod_figure(**spec_in)#, kwargs=spec2) #dlam=dlam, spatial_scale=spatial_scale, NY=NY)
fig2 = hdu2 = None
# Update master model
if GroupFLT is not None:
try:
ix = GroupFLT.catalog["NUMBER"] == self.beams[0].id
mag = GroupFLT.catalog["MAG_AUTO"][ix].data[0]
except:
mag = 22
sp = fit["cont1d"]
GroupFLT.compute_single_model(
id,
mag=mag,
size=-1,
store=False,
spectrum_1d=[sp.wave, sp.flux],
is_cgs=True,
get_beams=None,
in_place=True,
)
# 2D lines to drizzle
hdu_full = self.drizzle_fit_lines(
fit, pline, force_line=force_line, save_fits=True
)
fit["id"] = self.id
fit["fit_bg"] = self.fit_bg
fit["grism_files"] = [b.grism.parent_file for b in self.beams]
for item in ["A", "coeffs", "model_full", "model_cont"]:
if item in fit:
p = fit.pop(item)
# p = fit.pop('coeffs')
np.save("{0}_{1:05d}.zfit.npy".format(self.group_name, self.id), [fit])
fig.savefig("{0}_{1:05d}.zfit.png".format(self.group_name, self.id))
# fig2.savefig('{0}_{1:05d}.zfit.2D.png'.format(self.group_name, self.id))
# hdu2.writeto('{0}_{1:05d}.zfit.2D.fits'.format(self.group_name, self.id), overwrite=True, output_verify='silentfix')
label = "# id ra dec zbest "
data = "{0:7d} {1:.6f} {2:.6f} {3:.5f}".format(
self.id, self.ra, self.dec, fit["zbest"]
)
for grism in ["G800L", "G280", "G102", "G141", "GRISM"]:
label += " N{0}".format(grism)
if grism in self.Ngrism:
data += " {0:2d}".format(self.Ngrism[grism])
else:
data += " {0:2d}".format(0)
label += " chi2 DoF "
data += " {0:14.1f} {1:d} ".format(fit["chibest"], self.DoF)
for line in ["SII", "Ha", "OIII", "Hb", "Hg", "OII"]:
label += " {0} {0}_err".format(line)
if line in fit["line_flux"]:
flux = fit["line_flux"][line][0]
err = fit["line_flux"][line][1]
data += " {0:10.3e} {1:10.3e}".format(flux, err)
fp = open("{0}_{1:05d}.zfit.dat".format(self.group_name, self.id), "w")
fp.write(label + "\n")
fp.write(data + "\n")
fp.close()
fp = open("{0}_{1:05d}.zfit.beams.dat".format(self.group_name, self.id), "w")
fp.write("# file filter origin_x origin_y size padx pady bg\n")
for ib, beam in enumerate(self.beams):
msg = "{0:40s} {1:s} {2:5d} {3:5d} {4:5d} {5:5d} {6:5d}"
data = msg.format(
beam.grism.parent_file,
beam.grism.filter,
beam.direct.origin[0],
beam.direct.origin[1],
beam.direct.sh[0],
beam.direct.pad[1],
beam.direct.pad[0],
)
if self.fit_bg:
data += " {0:8.4f}".format(fit["coeffs_full"][ib])
else:
data += " {0:8.4f}".format(0.0)
fp.write(data + "\n")
fp.close()
# Save figures
plt_status = plt.rcParams["interactive"]
# if not plt_status:
# plt.close(fig)
# plt.close(fig2)
return fit, fig, fig2, hdu2, hdu_full
[docs] def apply_trace_shift(self, set_to_zero=False):
"""
Apply the current `beam.yoffset` to the trace of each beam.
Parameters
----------
set_to_zero : bool
Set beam.yoffset to zero
"""
indices = [[i] for i in range(self.N)]
if set_to_zero:
s0 = np.zeros(len(indices))
else:
s0 = [beam.beam.yoffset for beam in self.beams]
args = (self, indices, 0, False, False, True)
self.eval_trace_shift(s0, *args)
# Reset model profile for optimal extractions
for b in self.beams:
# b._parse_from_data()
if hasattr(b, "has_sys_err"):
delattr(b, "has_sys_err")
b._parse_from_data(**b._parse_params)
self._parse_beam_arrays()
[docs] def fit_trace_shift(
self,
split_groups=True,
max_shift=5,
tol=1.0e-2,
verbose=True,
lm=False,
fit_with_psf=False,
reset=False,
):
"""
Fit a global trace shift to the beams.
Parameters
----------
split_groups : bool
Split the fit into groups.
max_shift : float
Maximum allowed shift in pixels.
tol : float
Tolerance for the fit.
verbose : bool
Print status messages.
lm : bool
Use Levenberg-Marquardt optimization.
fit_with_psf : bool
Fit with the PSF model.
reset : bool
Reset the trace shifts to zero.
Returns
-------
shifts : `~numpy.ndarray`
Array of trace shifts.
out : `~numpy.ndarray`
Output from the optimization routine.
"""
from scipy.optimize import leastsq, minimize
if split_groups:
indices = []
for g in self.PA:
for p in self.PA[g]:
indices.append(self.PA[g][p])
else:
indices = [[i] for i in range(self.N)]
s0 = np.zeros(len(indices))
bounds = np.array([[-max_shift, max_shift]] * len(indices))
args = (self, indices, 0, lm, verbose, fit_with_psf)
if reset:
shifts = np.zeros(len(indices))
out = None
elif lm:
out = leastsq(
self.eval_trace_shift,
s0,
args=args,
Dfun=None,
full_output=0,
col_deriv=0,
ftol=1.49012e-08,
xtol=1.49012e-08,
gtol=0.0,
maxfev=0,
epsfcn=None,
factor=100,
diag=None,
)
shifts = out[0]
else:
out = minimize(
self.eval_trace_shift,
s0,
bounds=bounds,
args=args,
method="Powell",
tol=tol,
)
if out.x.shape == ():
shifts = [float(out.x)]
else:
shifts = out.x
# Apply to PSF if necessary
args = (self, indices, 0, lm, verbose, True)
self.eval_trace_shift(shifts, *args)
# Reset model profile for optimal extractions
for b in self.beams:
# b._parse_from_data()
b._parse_from_data(**b._parse_params)
# Needed for background modeling
if hasattr(b, "xp"):
delattr(b, "xp")
self._parse_beam_arrays()
self.initialize_masked_arrays()
return shifts, out
[docs] @staticmethod
def eval_trace_shift(shifts, self, indices, poly_order, lm, verbose, fit_with_psf):
"""
Evaluate the trace shifts.
Parameters
----------
shifts : `~numpy.ndarray`
Array of trace shifts.
self : `~grizli.multifit.MultiBeam`
Object to evaluate.
indices : list
List of indices to evaluate.
poly_order : int
Polynomial order to fit.
lm : bool
Use Levenberg-Marquardt optimization.
verbose : bool
Print status messages.
fit_with_psf : bool
Fit with the PSF model.
Returns
-------
chi2 : float
Chi-squared value.
"""
import scipy.ndimage as nd
for il, l in enumerate(indices):
for i in l:
beam = self.beams[i]
beam.beam.add_ytrace_offset(shifts[il])
if hasattr(self.beams[i].beam, "psf") & fit_with_psf:
# beam.model = nd.shift(beam.modelf.reshape(beam.sh_beam), (shifts[il], 0))
# This is slow, so run with fit_with_psf=False if possible
beam.init_epsf(
yoff=0, psf_params=beam.beam.psf_params # shifts[il],
)
beam.compute_model(use_psf=True)
m = beam.compute_model(in_place=False)
# beam.modelf = beam.model.flatten()
# beam.model = beam.modelf.reshape(beam.beam.sh_beam)
beam.flat_flam = beam.compute_model(in_place=False, is_cgs=True)
else:
# self.beams[i].beam.add_ytrace_offset(shifts[il])
# self.beams[i].compute_model(is_cgs=True)
beam.compute_model(use_psf=False)
self.flat_flam = np.hstack([b.beam.model.flatten() for b in self.beams])
self.poly_order = -1
self.init_poly_coeffs(poly_order=poly_order)
self.fit_bg = False
A = self.A_poly * 1
ok_temp = np.sum(A, axis=1) != 0
out_coeffs = np.zeros(A.shape[0])
y = self.scif
out = np.linalg.lstsq(A.T, y, rcond=utils.LSTSQ_RCOND)
lstsq_coeff, residuals, rank, s = out
coeffs = lstsq_coeff
out_coeffs = np.zeros(A.shape[0])
out_coeffs[ok_temp] = coeffs
modelf = np.dot(out_coeffs, A)
if lm:
# L-M, return residuals
if verbose:
print(
"{0} [{1}]".format(
utils.NO_NEWLINE,
" ".join(["{0:5.2f}".format(s) for s in shifts]),
)
)
return ((self.scif - modelf) * self.sivarf)[self.fit_mask]
chi2 = np.sum(((self.scif - modelf) ** 2 * self.ivarf)[self.fit_mask])
if verbose:
print(
"{0} [{1}] {2:6.2f}".format(
utils.NO_NEWLINE,
" ".join(["{0:5.2f}".format(s) for s in shifts]),
chi2 / self.DoF,
)
)
return chi2 / self.DoF
[docs] def drizzle_grisms_and_PAs(
self,
size=10,
fcontam=0,
flambda=False,
scale=1,
pixfrac=0.5,
kernel="square",
usewcs=False,
tfit=None,
diff=True,
grism_list=[
"G800L",
"G102",
"G141",
"GR150C",
"GR150R",
"F090W",
"F115W",
"F150W",
"F200W",
"F356W",
"F300M",
"F335M",
"F360M",
"F410M",
"F430M",
"F460M",
"F480M",
"F444W",
],
mask_segmentation=True,
reset_model=True,
make_figure=True,
fig_args=dict(
mask_segmentation=True, average_only=False, scale_size=1, cmap="viridis_r"
),
**kwargs,
):
"""
Make figure showing spectra at different orients/grisms
Parameters
----------
size : int
Image size in pixels
fcontam : float
Contamination parameter for drizzle.
flambda : bool
Convert to f-lambda units.
scale : float
Scale factor for drizzle.
pixfrac : float
Drizzle pixfrac parameter.
kernel : str
Drizzle kernel parameter.
usewcs : bool
Use WCS information for drizzle.
tfit : dict
Template fit results.
diff : bool
Show difference from the model.
grism_list : list
List of grisms to include.
mask_segmentation : bool
Mask segmentation regions.
reset_model : bool
Reset the model to a flat spectrum.
make_figure : bool
Make the figure.
fig_args : dict
Arguments for the figure.
Returns
-------
fig : `~matplotlib.figure.Figure`
Figure object.
"""
from matplotlib.ticker import MultipleLocator
# import pysynphot as S
if usewcs:
drizzle_function = drizzle_2d_spectrum_wcs
else:
drizzle_function = drizzle_2d_spectrum
if "zfit" in kwargs:
tfit = kwargs["zfit"]
NX = len(self.PA)
NY = 0
for g in self.PA:
NY = np.maximum(NY, len(self.PA[g]))
NY += 1
# keys = list(self.PA)
keys = []
for key in grism_list:
if key in self.PA:
keys.append(key)
if tfit is not None:
if "coeffs_full" in tfit:
bg = tfit["coeffs_full"][: self.N]
z_cont = tfit["zbest"]
else:
# fitting.GroupFitter
z_cont = tfit["z"]
bg = []
for k in tfit["cfit"]:
if k.startswith("bg "):
bg.append(tfit["cfit"][k][0])
bg = np.array(bg)
else:
# Fit background
try:
out = self.xfit_at_z(
z=0, templates={}, fitter="lstsq", poly_order=3, fit_background=True
)
bg = out[-3][: self.N]
except:
bg = [0] * self.N
for ib, beam in enumerate(self.beams):
beam.bg = bg[ib]
prim = pyfits.PrimaryHDU()
h0 = prim.header
h0["ID"] = (self.id, "Object ID")
h0["RA"] = (self.ra, "Right ascension")
h0["DEC"] = (self.dec, "Declination")
h0["ISFLAM"] = (flambda, "Pixels in f-lam units")
h0["FCONTAM"] = (fcontam, "Contamination parameter")
h0["NGRISM"] = (len(keys), "Number of grisms")
all_hdus = []
for ig, g in enumerate(keys):
all_beams = []
hdus = []
pas = list(self.PA[g].keys())
pas.sort()
h0["GRISM{0:03d}".format(ig + 1)] = (g, "Grism name")
h0["N" + g] = (len(pas), "Number of PAs for grism " + g)
for ipa, pa in enumerate(pas):
h0[g + "{0:02d}".format(ipa + 1)] = (pa, "PA")
beams = [self.beams[i] for i in self.PA[g][pa]]
all_beams.extend(beams)
# dlam = np.ceil(np.diff(beams[0].beam.lam)[0])*scale
dlam = GRISM_LIMITS[g][2] * scale
data = [beam.grism["SCI"] - beam.contam - beam.bg for beam in beams]
hdu = drizzle_function(
beams,
data=data,
wlimit=GRISM_LIMITS[g],
dlam=dlam,
spatial_scale=scale,
NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=0,
ds9=None,
mask_segmentation=mask_segmentation,
)
hdu[0].header["RA"] = (self.ra, "Right ascension")
hdu[0].header["DEC"] = (self.dec, "Declination")
hdu[0].header["GRISM"] = (g, "Grism")
hdu[0].header["PA"] = (pa, "Dispersion PA")
hdu[0].header["ISFLAM"] = (flambda, "Pixels in f-lam units")
hdu[0].header["CONF"] = (
beams[0].beam.conf.conf_file,
"Configuration file",
)
hdu[0].header["DLAM0"] = (
np.median(np.diff(beams[0].wave)),
"Native dispersion per pix",
)
# Contam
data = [beam.contam for beam in beams]
hdu_contam = drizzle_function(
beams,
data=data,
wlimit=GRISM_LIMITS[g],
dlam=dlam,
spatial_scale=scale,
NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=0,
ds9=None,
mask_segmentation=mask_segmentation,
)
hdu_contam[1].header["EXTNAME"] = "CONTAM"
hdu.append(hdu_contam[1])
# Continuum model
if tfit is not None:
m = tfit["cont1d"]
for beam in beams:
beam.compute_model(spectrum_1d=[m.wave, m.flux], is_cgs=True)
else:
if reset_model:
# simple flat spectrum
for beam in beams:
beam.compute_model()
data = []
for beam in beams:
if hasattr(beam.beam, "pscale_array"):
data.append(beam.beam.model * beam.beam.pscale_array)
else:
data.append(beam.beam.model)
hdu_model = drizzle_function(
beams,
data=data,
wlimit=GRISM_LIMITS[g],
dlam=dlam,
spatial_scale=scale,
NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=0,
ds9=None,
mask_segmentation=mask_segmentation,
)
hdu_model[1].header["EXTNAME"] = "MODEL"
if tfit is not None:
hdu_model[1].header["CONTIN1D"] = (True, "Model is fit continuum")
hdu_model[1].header["REDSHIFT"] = (
z_cont,
"Redshift of the continuum spectrum",
)
else:
hdu_model[1].header["CONTIN1D"] = (False, "Model is fit continuum")
hdu.append(hdu_model[1])
# Line kernel
if not usewcs:
h = hdu[1].header
# header keywords scaled to um
toA = 1.0e4
gau = utils.SpectrumTemplate(
central_wave=h["CRVAL1"] * toA, fwhm=h["CD1_1"] * toA
)
if reset_model:
for beam in beams:
beam.compute_model(
spectrum_1d=[gau.wave, gau.flux], is_cgs=True
)
data = [beam.beam.model for beam in beams]
h_kern = drizzle_function(
beams,
data=data,
wlimit=GRISM_LIMITS[g],
dlam=dlam,
spatial_scale=scale,
NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=0,
fill_wht=True,
ds9=None,
mask_segmentation=mask_segmentation,
)
kern = h_kern[1].data[
:, h["CRPIX1"] - 1 - size : h["CRPIX1"] - 1 + size
]
hdu_kern = pyfits.ImageHDU(
data=kern, header=h_kern[1].header, name="KERNEL"
)
hdu.append(hdu_kern)
else:
hdu["DSCI"].header["EXTNAME"] = "KERNEL"
# Pull out zeroth extension
for k in hdu[0].header:
hdu[1].header[k] = hdu[0].header[k]
for e in hdu[1:]:
e.header["EXTVER"] = "{0},{1}".format(g, pa)
hdus.append(hdu[1:])
# Stack of each grism
data = [beam.grism["SCI"] - beam.contam - beam.bg for beam in all_beams]
hdu = drizzle_function(
all_beams,
data=data,
wlimit=GRISM_LIMITS[g],
dlam=dlam,
spatial_scale=scale,
NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=fcontam,
ds9=None,
mask_segmentation=mask_segmentation,
)
hdu[0].header["RA"] = (self.ra, "Right ascension")
hdu[0].header["DEC"] = (self.dec, "Declination")
hdu[0].header["GRISM"] = (g, "Grism")
hdu[0].header["ISFLAM"] = (flambda, "Pixels in f-lam units")
hdu[0].header["CONF"] = (beams[0].beam.conf.conf_file, "Configuration file")
hdu[0].header["DLAM0"] = (
np.median(np.diff(beams[0].wave)),
"Native dispersion per pix",
)
# Full continuum model
if tfit is not None:
if diff > 1:
m = tfit["line1d"]
else:
m = tfit["cont1d"]
for beam in all_beams:
beam.compute_model(spectrum_1d=[m.wave, m.flux], is_cgs=True)
else:
if reset_model:
for beam in all_beams:
beam.compute_model()
# data = [beam.beam.model for beam in all_beams]
data = []
for beam in all_beams:
if hasattr(beam.beam, "pscale_array"):
data.append(beam.beam.model * beam.beam.pscale_array)
else:
data.append(beam.beam.model)
hdu_model = drizzle_function(
all_beams,
data=data,
wlimit=GRISM_LIMITS[g],
dlam=dlam,
spatial_scale=scale,
NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=fcontam,
ds9=None,
mask_segmentation=mask_segmentation,
)
hdu_model[1].header["EXTNAME"] = "MODEL"
if tfit is not None:
hdu_model[1].header["CONTIN1D"] = (True, "Model is fit continuum")
hdu_model[1].header["REDSHIFT"] = (
z_cont,
"Redshift of the continuum spectrum",
)
else:
hdu_model[1].header["CONTIN1D"] = (False, "Model is fit continuum")
hdu.append(hdu_model[1])
# Full kernel
h = hdu[1].header
# gau = S.GaussianSource(1.e-17, h['CRVAL1'], h['CD1_1']*1)
toA = 1.0e4
# gau = S.GaussianSource(1., h['CRVAL1']*toA, h['CD1_1']*toA)
gau = utils.SpectrumTemplate(
central_wave=h["CRVAL1"] * toA, fwhm=h["CD1_1"] * toA
)
if reset_model:
for beam in all_beams:
beam.compute_model(spectrum_1d=[gau.wave, gau.flux], is_cgs=True)
data = [beam.beam.model for beam in all_beams]
h_kern = drizzle_function(
all_beams,
data=data,
wlimit=GRISM_LIMITS[g],
dlam=dlam,
spatial_scale=scale,
NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=0,
fill_wht=True,
ds9=None,
mask_segmentation=mask_segmentation,
)
kern = h_kern[1].data[
:, int(h["CRPIX1"]) - 1 - size : int(h["CRPIX1"]) - 1 + size
]
hdu_kern = pyfits.ImageHDU(
data=kern, header=h_kern[1].header, name="KERNEL"
)
hdu.append(hdu_kern)
# Pull out zeroth extension
for k in hdu[0].header:
hdu[1].header[k] = hdu[0].header[k]
for e in hdu[1:]:
e.header["EXTVER"] = "{0}".format(g)
hdus.append(hdu[1:])
all_hdus.extend(hdus)
output_hdu = pyfits.HDUList([prim])
for hdu in all_hdus:
output_hdu.extend(hdu)
if make_figure:
fig = show_drizzle_HDU(output_hdu, diff=diff, **fig_args)
return output_hdu, fig
else:
return output_hdu # all_hdus
[docs] def flag_with_drizzled(
self, hdul, sigma=4, update=True, interp="nearest", verbose=True
):
"""
Update `MultiBeam` masks based on the blotted drizzled combined image
[in progress ... xxx]
Parameters
----------
hdul : `~astropy.io.fits.HDUList`
FITS HDU list output from `drizzle_grisms_and_PAs` or read from a
`stack.fits` file.
sigma : float
Residual threshold to flag.
update : bool
Update the mask.
interp : str
Interpolation method for `~drizzlepac.ablot`.
verbose : bool
Print status messages.
Returns
-------
Updates the individual `fit_mask` attributes of the individual beams
if `update==True`.
"""
try:
from drizzle.doblot import doblot
blotter = doblot
except:
from drizzlepac import ablot
blotter = ablot.do_blot
# Read the drizzled arrays
Ng = hdul[0].header["NGRISM"]
ref_wcs = {}
ref_data = {}
flag_grism = {}
for i in range(Ng):
g = hdul[0].header["GRISM{0:03d}".format(i + 1)]
ref_wcs[g] = pywcs.WCS(hdul["SCI", g].header)
ref_wcs[g].pscale = utils.get_wcs_pscale(ref_wcs[g])
ref_data[g] = hdul["SCI", g].data
flag_grism[g] = hdul[0].header["N{0}".format(g)] > 1
# Do the masking
for i, beam in enumerate(self.beams):
g = beam.grism.filter
if not flag_grism[g]:
continue
beam_header, flt_wcs = beam.full_2d_wcs()
blotted = blotter(
ref_data[g],
ref_wcs[g],
flt_wcs,
1,
coeffs=True,
interp=interp,
sinscl=1.0,
stepsize=10,
wcsmap=None,
)
resid = beam.grism["SCI"] - beam.contam - blotted
resid *= np.sqrt(beam.ivar)
blot_mask = (blotted != 0) & (np.abs(resid) < sigma)
if verbose:
print(
"Beam {0:>3d}: {1:>4d} new masked pixels".format(
i,
beam.fit_mask.sum()
- (beam.fit_mask & blot_mask.flatten()).sum(),
)
)
if update:
beam.fit_mask &= blot_mask.flatten()
if update:
self._parse_beams()
self.initialize_masked_arrays()
[docs] def oned_spectrum(
self,
tfit=None,
get_contam=True,
get_background=False,
masked_model=None,
**kwargs,
):
"""
Compute full 1D spectrum with optional best-fit model
Parameters
----------
tfit : dict
Output of `~grizli.fitting.mb.template_at_z`.
get_contam : bool
Extract the contamination model.
get_background : bool
Extract the background model.
masked_model : `~numpy.ndarray`
Masked model to extract.
Returns
-------
sp : dict
Dictionary of the extracted 1D spectra. Keys are the grism
names and the values are `~astropy.table.Table` objects.
"""
import astropy.units as u
# "Flat" spectrum to perform flux calibration
if self.Nphot > 0:
flat_data = self.flat_flam[self.fit_mask[: -self.Nphotbands]]
else:
flat_data = self.flat_flam[self.fit_mask]
sp_flat = self.optimal_extract(flat_data, **kwargs)
# Best-fit line and continuum models, with background fit
if tfit is not None:
bg_model = self.get_flat_background(tfit["coeffs"], apply_mask=True)
line_model = self.get_flat_model([tfit["line1d"].wave, tfit["line1d"].flux])
cont_model = self.get_flat_model([tfit["line1d"].wave, tfit["cont1d"].flux])
sp_line = self.optimal_extract(line_model, **kwargs)
sp_cont = self.optimal_extract(cont_model, **kwargs)
elif masked_model is not None:
bg_model = 0.0
sp_model = self.optimal_extract(masked_model, **kwargs)
else:
bg_model = 0.0
# Optimal spectral extraction
sp = self.optimal_extract(self.scif_mask[: self.Nspec] - bg_model, **kwargs)
if get_contam:
spc = self.optimal_extract(self.contamf_mask[: self.Nspec], **kwargs)
if (tfit is not None) & (get_background):
bgm = self.get_flat_background(tfit["coeffs"], apply_mask=True)
sp_bg = self.optimal_extract(bgm[: self.Nspec], **kwargs)
else:
sp_bg = None
# Loop through grisms, change units and add fit columns
# NB: setting units to "count / s" to comply with FITS standard,
# where count / s = electron / s
for k in sp:
sp[k]["flat"] = sp_flat[k]["flux"]
flat_unit = (u.count / u.s) / (u.erg / u.s / u.cm ** 2 / u.AA)
sp[k]["flat"].unit = flat_unit
sp[k]["flux"].unit = u.count / u.s
sp[k]["err"].unit = u.count / u.s
if get_contam:
sp[k]["contam"] = spc[k]["flux"]
sp[k]["contam"].unit = u.count / u.s
if tfit is not None:
sp[k]["line"] = sp_line[k]["flux"]
sp[k]["line"].unit = u.count / u.s
sp[k]["cont"] = sp_cont[k]["flux"]
sp[k]["cont"].unit = u.count / u.s
if masked_model is not None:
sp[k]["model"] = sp_model[k]["flux"]
sp[k]["model"].unit = u.count / u.s
if sp_bg is not None:
sp[k]["background"] = sp_bg[k]["flux"]
sp[k]["background"].unit = u.count / u.s
sp[k].meta["GRISM"] = (k, "Grism name")
# Metadata
exptime = count = 0
for pa in self.PA[k]:
for i in self.PA[k][pa]:
exptime += self.beams[i].grism.header["EXPTIME"]
count += 1
parent = (self.beams[i].grism.parent_file, "Parent file")
sp[k].meta["FILE{0:04d}".format(count)] = parent
sp[k].meta["NEXP"] = (count, "Number of exposures")
sp[k].meta["EXPTIME"] = (exptime, "Total exposure time")
sp[k].meta["NPA"] = (len(self.PA[k]), "Number of PAs")
# PSCALE
if hasattr(self, "pscale"):
if self.pscale is not None:
pscale = self.compute_scale_array(self.pscale, sp[k]["wave"])
sp[k]["pscale"] = pscale
sp[k].meta["PSCALEN"] = (len(self.pscale) - 1, "PSCALE order")
for i, p in enumerate(self.pscale):
sp[k].meta["PSCALE{0}".format(i)] = (
p,
"PSCALE parameter {0}".format(i),
)
return sp
[docs] def oned_spectrum_to_hdu(self, sp=None, outputfile=None, **kwargs):
"""
Generate 1D spectra fits HDUList
Parameters
----------
sp : optional, dict
Output of `~grizli.multifit.MultiBeam.oned_spectrum`. If None,
then run that function with `**kwargs`.
outputfile : None, str
If a string supplied, then write the `~astropy.io.fits.HDUList` to
a file.
Returns
-------
hdul : `~astropy.io.fits.HDUList`
FITS version of the 1D spectrum tables.
"""
from astropy.io.fits.convenience import table_to_hdu
# Generate the spectrum if necessary
if sp is None:
sp = self.oned_spectrum(**kwargs)
# Metadata in PrimaryHDU
prim = pyfits.PrimaryHDU()
prim.header["ID"] = (self.id, "Object ID")
prim.header["RA"] = (self.ra, "Right Ascension")
prim.header["DEC"] = (self.dec, "Declination")
prim.header["TARGET"] = (self.group_name, "Target Name")
prim.header["MW_EBV"] = (self.MW_EBV, "Galactic extinction E(B-V)")
for g in ["G102", "G141", "G800L"]:
if g in sp:
prim.header["N_{0}".format(g)] = sp[g].meta["NEXP"]
prim.header["T_{0}".format(g)] = sp[g].meta["EXPTIME"]
prim.header["PA_{0}".format(g)] = sp[g].meta["NPA"]
else:
prim.header["N_{0}".format(g)] = (0, "Number of exposures")
prim.header["T_{0}".format(g)] = (0, "Total exposure time")
prim.header["PA_{0}".format(g)] = (0, "Number of PAs")
for i, k in enumerate(sp):
prim.header["GRISM{0:03d}".format(i + 1)] = (k, "Grism name")
# Generate HDUList
hdul = [prim]
for k in sp:
hdu = table_to_hdu(sp[k])
hdu.header["EXTNAME"] = k
hdul.append(hdu)
# Outputs
hdul = pyfits.HDUList(hdul)
if outputfile is None:
return hdul
else:
hdul.writeto(outputfile, overwrite=True)
return hdul
[docs] def make_simple_hdulist(self):
"""
Make a`~astropy.io.fits.HDUList` object with just a simple
PrimaryHDU
"""
p = pyfits.PrimaryHDU()
p.header["ID"] = (self.id, "Object ID")
p.header["RA"] = (self.ra, "R.A.")
p.header["DEC"] = (self.dec, "Decl.")
p.header["NINPUT"] = (len(self.beams), "Number of drizzled beams")
p.header["HASLINES"] = ("", "Lines in this file")
for i, beam in enumerate(self.beams):
p.header["FILE{0:04d}".format(i + 1)] = (
beam.grism.parent_file,
"Parent filename",
)
p.header["GRIS{0:04d}".format(i + 1)] = (
beam.grism.filter,
"Beam grism element",
)
p.header["PA{0:04d}".format(i + 1)] = (
beam.get_dispersion_PA(),
"PA of dispersion axis",
)
return pyfits.HDUList(p)
[docs] def check_for_bad_PAs(
self, poly_order=1, chi2_threshold=1.5, fit_background=True, reinit=True
):
"""
Check for bad PAs based on chi-squared values
Parameters
----------
poly_order : int
Polynomial order for the background fit.
chi2_threshold : float
Threshold for chi-squared ratio to the minimum value.
fit_background : bool
Fit a polynomial background.
reinit : bool
Reinitialize the `~grizli.multifit.MultiBeam` object with thes
beams that pass the threshold.
Returns
-------
fit_log : dict
Log of the fit results.
keep_dict : dict
Dictionary of PAs to keep.
has_bad : bool
True if any PAs are flagged as bad.
"""
wave = np.linspace(2000, 2.5e4, 100)
poly_templates = utils.polynomial_templates(wave, order=poly_order)
fit_log = OrderedDict()
keep_dict = {}
has_bad = False
keep_beams = []
for g in self.PA:
fit_log[g] = OrderedDict()
keep_dict[g] = []
for pa in self.PA[g]:
beams = [self.beams[i] for i in self.PA[g][pa]]
mb_i = MultiBeam(
beams,
fcontam=self.fcontam,
sys_err=self.sys_err,
min_sens=self.min_sens,
min_mask=self.min_mask,
mask_resid=self.mask_resid,
MW_EBV=self.MW_EBV,
)
try:
chi2, _, _, _ = mb_i.xfit_at_z(
z=0, templates=poly_templates, fit_background=fit_background
)
except:
chi2 = 1e30
if False:
p_i = mb_i.template_at_z(
z=0,
templates=poly_templates,
fit_background=fit_background,
fitter="lstsq",
fwhm=1400,
get_uncertainties=2,
)
fit_log[g][pa] = {
"chi2": chi2,
"DoF": mb_i.DoF,
"chi_nu": chi2 / np.maximum(mb_i.DoF, 1),
}
min_chinu = 1e30
for pa in self.PA[g]:
min_chinu = np.minimum(min_chinu, fit_log[g][pa]["chi_nu"])
fit_log[g]["min_chinu"] = min_chinu
for pa in self.PA[g]:
fit_log[g][pa]["chinu_ratio"] = fit_log[g][pa]["chi_nu"] / min_chinu
if fit_log[g][pa]["chinu_ratio"] < chi2_threshold:
keep_dict[g].append(pa)
keep_beams.extend([self.beams[i] for i in self.PA[g][pa]])
else:
has_bad = True
if reinit:
self.beams = keep_beams
self._parse_beams(psf=self.psf_param_dict is not None)
return fit_log, keep_dict, has_bad
[docs]def get_redshift_fit_defaults():
"""
Default parameters for redshift fitting functions in `~grizli.multifit`.
"""
pzfit_def = dict(
zr=[0.5, 1.6],
dz=[0.005, 0.0004],
fwhm=0,
poly_order=0,
fit_background=True,
delta_chi2_threshold=0.004,
fitter="nnls",
prior=None,
templates={},
figsize=[8, 5],
fsps_templates=False,
)
pspec2_def = dict(dlam=0, spatial_scale=1, NY=20, figsize=[8, 3.5])
pline_def = dict(size=20, pixscale=0.1, pixfrac=0.2, kernel="square", wcs=None)
return pzfit_def, pspec2_def, pline_def
[docs]def drizzle_2d_spectrum(
beams,
data=None,
wlimit=[1.05, 1.75],
dlam=50,
spatial_scale=1,
NY=10,
pixfrac=0.6,
kernel="square",
convert_to_flambda=True,
fcontam=0.2,
fill_wht=False,
ds9=None,
**kwargs,
):
"""
Drizzle 2D spectrum from a list of beams
Parameters
----------
beams : list of `~.model.BeamCutout` objects
data : None or list
optionally, drizzle data specified in this list rather than the
contamination-subtracted arrays from each beam.
wlimit : [float, float]
Limits on the wavelength array to drizzle ([wlim, wmax])
dlam : float
Delta wavelength per pixel
spatial_scale : float
Relative scaling of the spatial axis (1 = native pixels)
NY : int
Size of the cutout in the spatial dimension, in output pixels
pixfrac : float
Drizzle PIXFRAC (for `kernel` = 'point')
kernel : str, ('square' or 'point')
Drizzle kernel to use
convert_to_flambda : bool, float
Convert the 2D spectrum to physical units using the sensitivity curves
and if float provided, scale the flux densities by that value
fcontam: float
Factor by which to scale the contamination arrays and add to the
pixel variances.
fill_wht: bool
Fill `wht==0` pixels of the beam weights with the median nonzero
value.
ds9: `~grizli.ds9.DS9`
Show intermediate steps of the drizzling
mask_segmentation: bool
Mask the segmentation image of the beams with the drizzled weight.
Returns
-------
hdu : `~astropy.io.fits.HDUList`
FITS HDUList with the drizzled 2D spectrum and weight arrays
"""
from astropy import log
from drizzlepac import adrizzle
adrizzle.log.setLevel("ERROR")
drizzler = adrizzle.do_driz
dfillval = 0
log.setLevel("ERROR")
# log.disable_warnings_logging()
NX = int(np.round(np.diff(wlimit)[0] * 1.0e4 / dlam)) // 2
center = np.mean(wlimit[:2]) * 1.0e4
out_header, output_wcs = utils.full_spectrum_wcsheader(
center_wave=center, dlam=dlam, NX=NX, spatial_scale=spatial_scale, NY=NY
)
sh = (out_header["NAXIS2"], out_header["NAXIS1"])
if not hasattr(output_wcs, "_naxis1"):
output_wcs._naxis2, output_wcs._naxis1 = sh
outsci = np.zeros(sh, dtype=np.float32)
outwht = np.zeros(sh, dtype=np.float32)
outctx = np.zeros(sh, dtype=np.int32)
outvar = np.zeros(sh, dtype=np.float32)
outwv = np.zeros(sh, dtype=np.float32)
outcv = np.zeros(sh, dtype=np.int32)
if data is None:
data = []
for i, beam in enumerate(beams):
# Contamination-subtracted
beam_data = beam.grism.data["SCI"] - beam.contam
data.append(beam_data)
for i, beam in enumerate(beams):
# Get specific WCS for each beam
beam_header, beam_wcs = beam.full_2d_wcs()
if not hasattr(beam_wcs, "pixel_shape"):
beam_wcs.pixel_shape = beam_wcs._naxis1, beam_wcs._naxis2
if not hasattr(beam_wcs, "_naxis1"):
beam_wcs._naxis1, beam_wcs._naxis2 = beam_wcs._naxis
# Downweight contamination
# wht = 1/beam.ivar + (fcontam*beam.contam)**2
# wht = np.asarray(1/wht,dtype=np.float32)
# wht[~np.isfinite(wht)] = 0.
contam_weight = np.exp(-(fcontam * np.abs(beam.contam) * np.sqrt(beam.ivar)))
wht = beam.ivar * contam_weight
wht[~np.isfinite(wht)] = 0.0
contam_weight[beam.ivar == 0] = 0
if fill_wht:
wht_mask = wht == 0
med_wht = np.median(wht[~wht_mask])
wht[wht_mask] = med_wht
# print('xx Fill weight: {0}'.format(med_wht))
data_i = data[i] * 1.0
scl = 1.0
if convert_to_flambda:
# data_i *= convert_to_flambda/beam.beam.sensitivity
# wht *= (beam.beam.sensitivity/convert_to_flambda)**2
scl = convert_to_flambda # /1.e-17
scl *= 1.0 / beam.flat_flam.reshape(beam.beam.sh_beam).sum(axis=0)
# scl = convert_to_flambda/beam.beam.sensitivity
data_i *= scl
wht *= (1 / scl) ** 2
# contam_weight *= scl
wht[~np.isfinite(data_i + scl)] = 0
contam_weight[~np.isfinite(data_i + scl)] = 0
data_i[~np.isfinite(data_i + scl)] = 0
# Go drizzle
# Contamination-cleaned
drizzler(
data_i,
beam_wcs,
wht,
output_wcs,
outsci,
outwht,
outctx,
1.0,
"cps",
1,
wcslin_pscale=1.0,
uniqid=1,
pixfrac=pixfrac,
kernel=kernel,
fillval=dfillval,
)
# For variance
drizzler(
contam_weight,
beam_wcs,
wht,
output_wcs,
outvar,
outwv,
outcv,
1.0,
"cps",
1,
wcslin_pscale=1.0,
uniqid=1,
pixfrac=pixfrac,
kernel=kernel,
fillval=dfillval,
)
if ds9 is not None:
ds9.view(outsci / output_wcs.pscale ** 2, header=out_header)
# if False:
# # Plot the spectra for testing
# w, f, e = beam.beam.trace_extract(data_i, ivar=wht, r=3)
# clip = (f/e > 0.5)
# clip &= (e < 2*np.median(e[clip]))
# plt.errorbar(w[clip], f[clip], e[clip], marker='.', color='k', alpha=0.5, ecolor='0.8', linestyle='None')
# dw = np.median(np.diff(w))
# Correct for drizzle scaling
area_ratio = 1.0 / output_wcs.pscale ** 2
# Preserve flux (has to preserve aperture flux along spatial axis but
# average in spectral axis).
# area_ratio *= spatial_scale
# preserve flux density
flux_density_scale = spatial_scale ** 2
# science
outsci *= area_ratio * flux_density_scale
# variance
outvar *= area_ratio / outwv * flux_density_scale ** 2
outwht = 1 / outvar
outwht[(outvar == 0) | (~np.isfinite(outwht))] = 0
# if True:
# # Plot for testing....
# yp, xp = np.indices(outsci.shape)
# mask = np.abs(yp-NY) <= 3/spatial_scale
# fl = (outsci*mask).sum(axis=0)
# flv = (1/outwht*mask).sum(axis=0)
#
# wi = grizli.stack.StackedSpectrum.get_wavelength_from_header(out_header)
#
# plt.errorbar(wi[:-1], fl[1:], np.sqrt(flv)[1:], alpha=0.8) #*area_ratio)
# return outwht, outsci, outvar, outwv, output_wcs.pscale
p = pyfits.PrimaryHDU()
p.header["ID"] = (beams[0].id, "Object ID")
p.header["WMIN"] = (wlimit[0], "Minimum wavelength")
p.header["WMAX"] = (wlimit[1], "Maximum wavelength")
p.header["DLAM"] = (dlam, "Delta wavelength")
p.header["SSCALE"] = (spatial_scale, "Spatial scale factor w.r.t native")
p.header["FCONTAM"] = (fcontam, "Contamination weight")
p.header["PIXFRAC"] = (pixfrac, "Drizzle PIXFRAC")
p.header["DRIZKRNL"] = (kernel, "Drizzle kernel")
p.header["BEAM"] = (beams[0].beam.beam, "Grism order")
p.header["NINPUT"] = (len(beams), "Number of drizzled beams")
exptime = 0.0
for i, beam in enumerate(beams):
p.header["FILE{0:04d}".format(i + 1)] = (
beam.grism.parent_file,
"Parent filename",
)
p.header["GRIS{0:04d}".format(i + 1)] = (
beam.grism.filter,
"Beam grism element",
)
p.header["PA{0:04d}".format(i + 1)] = (
beam.get_dispersion_PA(),
"PA of dispersion axis",
)
exptime += beam.grism.exptime
p.header["EXPTIME"] = (exptime, "Total exposure time [s]")
h = out_header.copy()
grism_sci = pyfits.ImageHDU(data=outsci, header=h, name="SCI")
grism_wht = pyfits.ImageHDU(data=outwht, header=h, name="WHT")
hdul = pyfits.HDUList([p, grism_sci, grism_wht])
return hdul
[docs]def drizzle_to_wavelength(
beams,
wcs=None,
ra=0.0,
dec=0.0,
wave=1.0e4,
size=5,
pixscale=0.1,
pixfrac=0.6,
kernel="square",
theta=0.0,
direct_extension="REF",
fcontam=0.2,
custom_key=None,
ds9=None,
):
"""
Drizzle a cutout at a specific wavelength from a list of `~grizli.model.BeamCutout` objects
Parameters
----------
beams : list of `~.model.BeamCutout` objects.
wcs : `~astropy.wcs.WCS` or None
Pre-determined WCS. If not specified, generate one based on ``ra``,
``dec``, ``pixscale`` and ``pixscale``.
ra, dec, wave : float
Sky coordinates and central wavelength
size : float
Size of the output thumbnail, in arcsec
pixscale : float
Pixel scale of the output thumbnail, in arcsec
pixfrac : float
Drizzle PIXFRAC (for ``kernel`` = 'point')
kernel : str, ('square' or 'point')
Drizzle kernel to use
theta : float
Position angle of output WCS
direct_extension : str, ('SCI' or 'REF')
Extension of ``self.direct.data`` do drizzle for the thumbnail
fcontam: float
Factor by which to scale the contamination arrays and add to the
pixel variances.
custom_key : str
Key of `beam.grism.data` dictionary to use as the science array for the
drizzled output
ds9 : `~grizli.ds9.DS9`, optional
Display each step of the drizzling to an open DS9 window
Returns
-------
hdu : `~astropy.io.fits.HDUList`
FITS HDUList with the drizzled thumbnail, line and continuum
cutouts.
"""
# try:
# import drizzle
# if drizzle.__version__ != '1.12.99':
# # Not the fork that works for all input/output arrays
# raise(ImportError)
#
# #print('drizzle!!')
# from drizzle.dodrizzle import dodrizzle
# drizzler = dodrizzle
# dfillval = '0'
# except:
from drizzlepac import adrizzle
adrizzle.log.setLevel("ERROR")
drizzler = adrizzle.do_driz
dfillval = 0
# Nothing to do
if len(beams) == 0:
return False
# Get output header and WCS
if wcs is None:
header, output_wcs = utils.make_wcsheader(
ra=ra, dec=dec, size=size, theta=theta, pixscale=pixscale, get_hdu=False
)
else:
output_wcs = wcs.copy()
if not hasattr(output_wcs, "pscale"):
output_wcs.pscale = utils.get_wcs_pscale(output_wcs)
header = utils.to_header(output_wcs, relax=True)
if not hasattr(output_wcs, "_naxis1"):
output_wcs._naxis1, output_wcs._naxis2 = output_wcs._naxis
# Initialize data
sh = (header["NAXIS2"], header["NAXIS1"])
outsci = np.zeros(sh, dtype=np.float32)
outwht = np.zeros(sh, dtype=np.float32)
outctx = np.zeros(sh, dtype=np.int32)
coutsci = np.zeros(sh, dtype=np.float32)
coutwht = np.zeros(sh, dtype=np.float32)
coutctx = np.zeros(sh, dtype=np.int32)
xoutsci = np.zeros(sh, dtype=np.float32)
xoutwht = np.zeros(sh, dtype=np.float32)
xoutctx = np.zeros(sh, dtype=np.int32)
# direct_filters = np.unique([b.direct.filter for b in self.beams])
all_direct_filters = []
for beam in beams:
if direct_extension == "REF":
if beam.direct["REF"] is None:
filt_i = beam.direct.ref_filter
direct_extension = "SCI"
else:
filt_i = beam.direct.filter
all_direct_filters.append(filt_i)
direct_filters = np.unique(all_direct_filters)
doutsci, doutwht, doutctx = {}, {}, {}
for f in direct_filters:
doutsci[f] = np.zeros(sh, dtype=np.float32)
doutwht[f] = np.zeros(sh, dtype=np.float32)
doutctx[f] = np.zeros(sh, dtype=np.int32)
# doutsci = np.zeros(sh, dtype=np.float32)
# doutwht = np.zeros(sh, dtype=np.float32)
# doutctx = np.zeros(sh, dtype=np.int32)
# Loop through beams and run drizzle
for i, beam in enumerate(beams):
# Get specific wavelength WCS for each beam
beam_header, beam_wcs = beam.get_wavelength_wcs(wave)
if not hasattr(beam_wcs, "pixel_shape"):
beam_wcs.pixel_shape = beam_wcs._naxis1, beam_wcs._naxis2
if not hasattr(beam_wcs, "_naxis1"):
beam_wcs._naxis1, beam_wcs._naxis2 = beam_wcs._naxis
# Make sure CRPIX set correctly for the SIP header
for j in [0, 1]:
# if beam_wcs.sip is not None:
# beam_wcs.sip.crpix[j] = beam_wcs.wcs.crpix[j]
if beam.direct.wcs.sip is not None:
beam.direct.wcs.sip.crpix[j] = beam.direct.wcs.wcs.crpix[j]
for wcs_ext in [beam_wcs.sip]:
if wcs_ext is not None:
wcs_ext.crpix[j] = beam_wcs.wcs.crpix[j]
# ACS requires additional wcs attributes
ACS_CRPIX = [4096 / 2, 2048 / 2]
dx_crpix = beam_wcs.wcs.crpix[0] - ACS_CRPIX[0]
dy_crpix = beam_wcs.wcs.crpix[1] - ACS_CRPIX[1]
for wcs_ext in [
beam_wcs.cpdis1,
beam_wcs.cpdis2,
beam_wcs.det2im1,
beam_wcs.det2im2,
]:
if wcs_ext is not None:
wcs_ext.crval[0] += dx_crpix
wcs_ext.crval[1] += dy_crpix
if custom_key is None:
beam_data = beam.grism.data["SCI"] - beam.contam
if hasattr(beam, "background"):
beam_data -= beam.background
if hasattr(beam, "extra_lines"):
beam_data -= beam.extra_lines
beam_continuum = beam.beam.model * 1
if hasattr(beam.beam, "pscale_array"):
beam_continuum *= beam.beam.pscale_array
else:
beam_data = beam.grism.data[custom_key] * 1
beam_continuum = beam_data * 0
# Downweight contamination
if fcontam > 0:
# wht = 1/beam.ivar + (fcontam*beam.contam)**2
# wht = np.asarray(1/wht,dtype=np.float32)
# wht[~np.isfinite(wht)] = 0.
contam_weight = np.exp(
-(fcontam * np.abs(beam.contam) * np.sqrt(beam.ivar))
)
wht = beam.ivar * contam_weight
wht[~np.isfinite(wht)] = 0.0
else:
wht = beam.ivar * 1
# Convert to f_lambda integrated line fluxes:
# (Inverse of the aXe sensitivity) x (size of pixel in \AA)
sens = np.interp(wave, beam.beam.lam, beam.beam.sensitivity, left=0, right=0)
dlam = np.interp(wave, beam.beam.lam[1:], np.diff(beam.beam.lam))
# 1e-17 erg/s/cm2 #, scaling closer to e-/s
sens *= 1.0e-17
sens *= 1.0 / dlam
if sens == 0:
continue
else:
wht *= sens ** 2
beam_data /= sens
beam_continuum /= sens
# Go drizzle
# Contamination-cleaned
drizzler(
beam_data,
beam_wcs,
wht,
output_wcs,
outsci,
outwht,
outctx,
1.0,
"cps",
1,
wcslin_pscale=beam.grism.wcs.pscale,
uniqid=1,
pixfrac=pixfrac,
kernel=kernel,
fillval=dfillval,
wcsmap=utils.WCSMapAll,
)
# Continuum
drizzler(
beam_continuum,
beam_wcs,
wht,
output_wcs,
coutsci,
coutwht,
coutctx,
1.0,
"cps",
1,
wcslin_pscale=beam.grism.wcs.pscale,
uniqid=1,
pixfrac=pixfrac,
kernel=kernel,
fillval=dfillval,
wcsmap=utils.WCSMapAll,
)
# Contamination
drizzler(
beam.contam,
beam_wcs,
wht,
output_wcs,
xoutsci,
xoutwht,
xoutctx,
1.0,
"cps",
1,
wcslin_pscale=beam.grism.wcs.pscale,
uniqid=1,
pixfrac=pixfrac,
kernel=kernel,
fillval=dfillval,
wcsmap=utils.WCSMapAll,
)
# Direct thumbnail
filt_i = all_direct_filters[i]
if direct_extension == "REF":
thumb = beam.direct["REF"]
thumb_wht = np.asarray((thumb != 0) * 1, dtype=np.float32)
else:
thumb = beam.direct[direct_extension] # /beam.direct.photflam
thumb_wht = 1.0 / (beam.direct.data["ERR"] / beam.direct.photflam) ** 2
thumb_wht[~np.isfinite(thumb_wht)] = 0
if not hasattr(beam.direct.wcs, "pixel_shape"):
beam.direct.wcs.pixel_shape = (
beam.direct.wcs._naxis1,
beam.direct.wcs._naxis2,
)
if not hasattr(beam.direct.wcs, "_naxis1"):
beam.direct.wcs._naxis1, beam.direct.wcs._naxis2 = beam.direct.wcs._naxis
drizzler(
thumb,
beam.direct.wcs,
thumb_wht,
output_wcs,
doutsci[filt_i],
doutwht[filt_i],
doutctx[filt_i],
1.0,
"cps",
1,
wcslin_pscale=beam.direct.wcs.pscale,
uniqid=1,
pixfrac=pixfrac,
kernel=kernel,
fillval=dfillval,
wcsmap=utils.WCSMapAll,
)
# Show in ds9
if ds9 is not None:
ds9.view((outsci - coutsci), header=header)
# Scaling of drizzled outputs
outwht *= (beams[0].grism.wcs.pscale / output_wcs.pscale) ** 4
coutwht *= (beams[0].grism.wcs.pscale / output_wcs.pscale) ** 4
xoutwht *= (beams[0].grism.wcs.pscale / output_wcs.pscale) ** 4
for filt_i in all_direct_filters:
doutwht[filt_i] *= (beams[0].direct.wcs.pscale / output_wcs.pscale) ** 4
# Make output FITS products
p = pyfits.PrimaryHDU()
p.header["ID"] = (beams[0].id, "Object ID")
p.header["RA"] = (ra, "Central R.A.")
p.header["DEC"] = (dec, "Central Decl.")
p.header["PIXFRAC"] = (pixfrac, "Drizzle PIXFRAC")
p.header["DRIZKRNL"] = (kernel, "Drizzle kernel")
p.header["NINPUT"] = (len(beams), "Number of drizzled beams")
for i, beam in enumerate(beams):
p.header["FILE{0:04d}".format(i + 1)] = (
beam.grism.parent_file,
"Parent filename",
)
p.header["GRIS{0:04d}".format(i + 1)] = (
beam.grism.filter,
"Beam grism element",
)
p.header["PA{0:04d}".format(i + 1)] = (
beam.get_dispersion_PA(),
"PA of dispersion axis",
)
h = header.copy()
h["ID"] = (beam.id, "Object ID")
h["PIXFRAC"] = (pixfrac, "Drizzle PIXFRAC")
h["DRIZKRNL"] = (kernel, "Drizzle kernel")
p.header["NDFILT"] = len(direct_filters), "Number of direct image filters"
for i, filt_i in enumerate(direct_filters):
p.header["DFILT{0:02d}".format(i + 1)] = filt_i
p.header["NFILT{0:02d}".format(i + 1)] = (
all_direct_filters.count(filt_i),
"Number of beams with this direct filter",
)
HDUL = [p]
for i, filt_i in enumerate(direct_filters):
h["FILTER"] = (filt_i, "Direct image filter")
thumb_sci = pyfits.ImageHDU(data=doutsci[filt_i], header=h, name="DSCI")
thumb_wht = pyfits.ImageHDU(data=doutwht[filt_i], header=h, name="DWHT")
thumb_sci.header["EXTVER"] = filt_i
thumb_wht.header["EXTVER"] = filt_i
HDUL += [thumb_sci, thumb_wht]
# thumb_seg = pyfits.ImageHDU(data=seg_slice, header=h, name='DSEG')
h["FILTER"] = (beam.grism.filter, "Grism filter")
h["WAVELEN"] = (wave, "Central wavelength")
grism_sci = pyfits.ImageHDU(data=outsci - coutsci, header=h, name="LINE")
grism_cont = pyfits.ImageHDU(data=coutsci, header=h, name="CONTINUUM")
grism_contam = pyfits.ImageHDU(data=xoutsci, header=h, name="CONTAM")
grism_wht = pyfits.ImageHDU(data=outwht, header=h, name="LINEWHT")
# HDUL = [p, thumb_sci, thumb_wht, grism_sci, grism_cont, grism_contam, grism_wht]
HDUL += [grism_sci, grism_cont, grism_contam, grism_wht]
return pyfits.HDUList(HDUL)
[docs]def show_drizzle_HDU(
hdu,
diff=True,
mask_segmentation=False,
average_only=False,
scale_size=1,
cmap="viridis_r",
show_labels=True,
width_ratio=0.2,
**kwargs,
):
"""
Make a figure from the multiple extensions in the drizzled grism file.
Parameters
----------
hdu : `~astropy.io.fits.HDUList`
HDU list output by `drizzle_grisms_and_PAs`.
diff : bool
If True, then plot the stacked spectrum minus the model.
mask_segmentation : bool
Mask the segmentation image of the beams with the drizzled weight.
Not used.
average_only : bool
If True, then only show the average spectrum for each grism.
scale_size : float
Scale the size of the figure.
cmap : str
Colormap for the 2D spectra.
show_labels : bool
Show the wavelength labels.
width_ratio : float
Width ratio of the kernel and spectrum panels.
kwargs : dict
Passed to `~matplotlib.pyplot.imshow`.
Returns
-------
fig : `~matplotlib.figure.Figure`
The figure.
"""
from collections import OrderedDict
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import MultipleLocator
h0 = hdu[0].header
NX = h0["NGRISM"]
NY = 0
grisms = OrderedDict()
for ig in range(NX):
g = h0["GRISM{0:03d}".format(ig + 1)]
NY = np.maximum(NY, h0["N" + g])
grisms[g] = h0["N" + g]
NY += 1
widths = []
for i in range(NX):
widths.extend([width_ratio, 1])
if average_only:
NY = 1
fig = plt.figure(
figsize=(NX * scale_size / width_ratio, NY * scale_size + 0.33)
)
gs = GridSpec(NY, NX * 2, width_ratios=widths)
else:
fig = plt.figure(figsize=(NX * scale_size / width_ratio, 1 * NY * scale_size))
gs = GridSpec(NY, NX * 2, height_ratios=[1] * NY, width_ratios=widths)
for ig, g in enumerate(grisms):
sci_i = hdu["SCI", g]
wht_i = hdu["WHT", g]
model_i = hdu["MODEL", g]
kern_i = hdu["KERNEL", g]
h_i = sci_i.header
clip = wht_i.data > 0
if clip.sum() == 0:
clip = np.isfinite(wht_i.data)
avg_rms = 1 / np.median(np.sqrt(wht_i.data[clip]))
vmax = np.maximum(1.1 * np.percentile(sci_i.data[clip], 98), 5 * avg_rms)
vmax_kern = 1.1 * np.percentile(kern_i.data, 99.5)
# Kernel
ax = fig.add_subplot(gs[NY - 1, ig * 2 + 0])
sh = kern_i.data.shape
extent = [0, sh[1], 0, sh[0]]
ax.imshow(
kern_i.data,
origin="lower",
interpolation="Nearest",
vmin=-0.1 * vmax_kern,
vmax=vmax_kern,
cmap=cmap,
extent=extent,
aspect="auto",
)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.xaxis.set_tick_params(length=0)
ax.yaxis.set_tick_params(length=0)
# Spectrum
sh = sci_i.data.shape
extent = [h_i["WMIN"], h_i["WMAX"], 0, sh[0]]
ax = fig.add_subplot(gs[NY - 1, ig * 2 + 1])
if diff:
# print('xx DIFF!')
m = model_i.data
else:
m = 0
ax.imshow(
sci_i.data - m,
origin="lower",
interpolation="Nearest",
vmin=-0.1 * vmax,
vmax=vmax,
extent=extent,
cmap=cmap,
aspect="auto",
)
ax.set_yticklabels([])
ax.set_xlabel(r"$\lambda$ ($\mu$m) - " + g)
ax.xaxis.set_major_locator(MultipleLocator(GRISM_MAJOR[g]))
if average_only:
iters = []
else:
iters = range(grisms[g])
for ip in iters:
# print(ip, ig)
pa = h0["{0}{1:02d}".format(g, ip + 1)]
sci_i = hdu["SCI", "{0},{1}".format(g, pa)]
wht_i = hdu["WHT", "{0},{1}".format(g, pa)]
kern_i = hdu["KERNEL", "{0},{1}".format(g, pa)]
h_i = sci_i.header
# Kernel
ax = fig.add_subplot(gs[ip, ig * 2 + 0])
sh = kern_i.data.shape
extent = [0, sh[1], 0, sh[0]]
ax.imshow(
kern_i.data,
origin="lower",
interpolation="Nearest",
vmin=-0.1 * vmax_kern,
vmax=vmax_kern,
extent=extent,
cmap=cmap,
aspect="auto",
)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.xaxis.set_tick_params(length=0)
ax.yaxis.set_tick_params(length=0)
# Spectrum
sh = sci_i.data.shape
extent = [h_i["WMIN"], h_i["WMAX"], 0, sh[0]]
ax = fig.add_subplot(gs[ip, ig * 2 + 1])
ax.imshow(
sci_i.data,
origin="lower",
interpolation="Nearest",
vmin=-0.1 * vmax,
vmax=vmax,
extent=extent,
cmap=cmap,
aspect="auto",
)
ax.set_yticklabels([])
ax.set_xticklabels([])
ax.xaxis.set_major_locator(MultipleLocator(GRISM_MAJOR[g]))
if show_labels:
ax.text(
0.015,
0.94,
"{0:3.0f}".format(pa),
ha="left",
va="top",
transform=ax.transAxes,
fontsize=8,
backgroundcolor="w",
)
if (ig == (NX - 1)) & (ip == 0) & show_labels:
ax.text(
0.98,
0.94,
"ID = {0}".format(h0["ID"]),
ha="right",
va="top",
transform=ax.transAxes,
fontsize=8,
backgroundcolor="w",
)
if average_only:
# pass
gs.tight_layout(fig, pad=0.01)
else:
gs.tight_layout(fig, pad=0.1)
return fig
[docs]def drizzle_2d_spectrum_wcs(
beams,
data=None,
wlimit=[1.05, 1.75],
dlam=50,
spatial_scale=1,
NY=10,
pixfrac=0.6,
kernel="square",
convert_to_flambda=True,
fcontam=0.2,
fill_wht=False,
ds9=None,
mask_segmentation=True,
):
"""
Drizzle 2D spectrum from a list of beams
Parameters
----------
beams : list of `~.model.BeamCutout` objects
data : None or list
optionally, drizzle data specified in this list rather than the
contamination-subtracted arrays from each beam.
wlimit : [float, float]
Limits on the wavelength array to drizzle ([wlim, wmax])
dlam : float
Delta wavelength per pixel
spatial_scale : float
Relative scaling of the spatial axis (1 = native pixels)
NY : int
Size of the cutout in the spatial dimension, in output pixels
pixfrac : float
Drizzle PIXFRAC (for `kernel` = 'point')
kernel : str, ('square' or 'point')
Drizzle kernel to use
convert_to_flambda : bool, float
Convert the 2D spectrum to physical units using the sensitivity curves
and if float provided, scale the flux densities by that value
fcontam: float
Factor by which to scale the contamination arrays and add to the
pixel variances.
fill_wht : bool
Fill in the drizzled weight array with the sum of the input weights.
(Not used.)
ds9: `~grizli.ds9.DS9`
Show intermediate steps of the drizzling
Returns
-------
hdu : `~astropy.io.fits.HDUList`
FITS HDUList with the drizzled 2D spectrum and weight arrays
"""
# try:
# import drizzle
# if drizzle.__version__ != '1.12.99':
# # Not the fork that works for all input/output arrays
# raise(ImportError)
#
# #print('drizzle!!')
# from drizzle.dodrizzle import dodrizzle
# drizzler = dodrizzle
# dfillval = '0'
# except:
from drizzlepac import adrizzle
adrizzle.log.setLevel("ERROR")
drizzler = adrizzle.do_driz
dfillval = 0
from stwcs import distortion
from astropy import log
log.setLevel("ERROR")
# log.disable_warnings_logging()
adrizzle.log.setLevel("ERROR")
NX = int(np.round(np.diff(wlimit)[0] * 1.0e4 / dlam)) // 2
center = np.mean(wlimit[:2]) * 1.0e4
out_header, output_wcs = utils.make_spectrum_wcsheader(
center_wave=center, dlam=dlam, NX=NX, spatial_scale=spatial_scale, NY=NY
)
pixscale = 0.128 * spatial_scale
# # Get central RA, reference pixel of beam[0]
# #rd = beams[0].get_sky_coords()
# x0 = beams[0].beam.x0.reshape((1,2))
# #x0[0,1] += beam.direct.origin[1]-beam.grism.origin[1]
# rd = beam.grism.wcs.all_pix2world(x0,1)[0]
# theta = 270-beams[0].get_dispersion_PA()
# out_header, output_wcs = utils.make_wcsheader(ra=rd[0], dec=rd[1], size=[50,10], pixscale=pixscale, get_hdu=False, theta=theta)
if True:
theta = -np.arctan2(np.diff(beams[0].beam.ytrace)[0], 1)
undist_wcs = distortion.utils.output_wcs([beams[0].grism.wcs], undistort=True)
undist_wcs = utils.transform_wcs(
undist_wcs, rotation=theta, scale=undist_wcs.pscale / pixscale
)
output_wcs = undist_wcs.copy()
out_header = utils.to_header(output_wcs)
# Direct image
d_undist_wcs = distortion.utils.output_wcs(
[beams[0].direct.wcs], undistort=True
)
d_undist_wcs = utils.transform_wcs(
d_undist_wcs, rotation=0.0, scale=d_undist_wcs.pscale / pixscale
)
d_output_wcs = d_undist_wcs.copy()
# Make square
if hasattr(d_output_wcs, "_naxis1"):
nx1, nx2 = d_output_wcs._naxis1, d_output_wcs._naxis2
else:
nx1, nx2 = d_output_wcs._naxis
d_output_wcs._naxis1, d_output_wcs._naxis2 = nx1, nx2
dx = nx1 - nx2
if hasattr(d_output_wcs, "_naxis1"):
d_output_wcs._naxis1 = d_output_wcs._naxis2
else:
d_output_wcs._naxis[0] = d_output_wcs._naxis[1]
d_output_wcs._naxis1 = d_output_wcs._naxis2 = d_output_wcs._naxis[0]
d_output_wcs.wcs.crpix[0] -= dx / 2.0
d_out_header = utils.to_header(d_output_wcs)
# delattr(output_wcs, 'orientat')
# beam_header = utils.to_header(beam_wcs)
# output_wcs = beam_wcs
# output_wcs = pywcs.WCS(beam_header, relax=True)
# output_wcs.pscale = utils.get_wcs_pscale(output_wcs)
# shift CRPIX to reference position of beam[0]
sh = (out_header["NAXIS2"], out_header["NAXIS1"])
sh_d = (d_out_header["NAXIS2"], d_out_header["NAXIS1"])
outsci = np.zeros(sh, dtype=np.float32)
outwht = np.zeros(sh, dtype=np.float32)
outctx = np.zeros(sh, dtype=np.int32)
doutsci = np.zeros(sh_d, dtype=np.float32)
doutwht = np.zeros(sh_d, dtype=np.float32)
doutctx = np.zeros(sh_d, dtype=np.int32)
outvar = np.zeros(sh, dtype=np.float32)
outwv = np.zeros(sh, dtype=np.float32)
outcv = np.zeros(sh, dtype=np.int32)
outls = np.zeros(sh, dtype=np.float32)
outlw = np.zeros(sh, dtype=np.float32)
outlc = np.zeros(sh, dtype=np.int32)
if data is None:
data = []
for i, beam in enumerate(beams):
# Contamination-subtracted
beam_data = beam.grism.data["SCI"] - beam.contam
data.append(beam_data)
for i, beam in enumerate(beams):
# Get specific WCS for each beam
beam_header, beam_wcs = beam.get_2d_wcs()
beam_wcs = beam.grism.wcs.deepcopy()
# Shift SIP reference
dx_sip = beam.grism.origin[1] - beam.direct.origin[1]
# beam_wcs.sip.crpix[0] += dx_sip
for wcs_ext in [beam_wcs.sip]:
if wcs_ext is not None:
wcs_ext.crpix[0] += dx_sip
for wcs_ext in [
beam_wcs.cpdis1,
beam_wcs.cpdis2,
beam_wcs.det2im1,
beam_wcs.det2im2,
]:
if wcs_ext is not None:
wcs_ext.crval[0] += dx_sip
# Shift y for trace
xy0 = beam.grism.wcs.all_world2pix(output_wcs.wcs.crval.reshape((1, 2)), 0)[0]
dy = np.interp(xy0[0], np.arange(beam.beam.sh_beam[1]), beam.beam.ytrace)
# beam_wcs.sip.crpix[1] += dy
beam_wcs.wcs.crpix[1] += dy
for wcs_ext in [beam_wcs.sip]:
if wcs_ext is not None:
wcs_ext.crpix[1] += dy
for wcs_ext in [
beam_wcs.cpdis1,
beam_wcs.cpdis2,
beam_wcs.det2im1,
beam_wcs.det2im2,
]:
if wcs_ext is not None:
wcs_ext.crval[1] += dy
if not hasattr(beam_wcs, "pixel_shape"):
beam_wcs.pixel_shape = beam_wcs._naxis1, beam_wcs._naxis2
if not hasattr(beam_wcs, "_naxis1"):
beam_wcs._naxis1, beam_wcs._naxis2 = beam_wcs._naxis
d_beam_wcs = beam.direct.wcs
if beam.direct["REF"] is None:
d_wht = 1.0 / beam.direct["ERR"] ** 2
d_wht[~np.isfinite(d_wht)] = 0
d_sci = beam.direct["SCI"] * 1
else:
d_sci = beam.direct["REF"] * 1
d_wht = d_sci * 0.0 + 1
if mask_segmentation:
d_sci *= beam.beam.seg == beam.id
# Downweight contamination
# wht = 1/beam.ivar + (fcontam*beam.contam)**2
# wht = np.asarray(1/wht,dtype=np.float32)
# wht[~np.isfinite(wht)] = 0.
contam_weight = np.exp(-(fcontam * np.abs(beam.contam) * np.sqrt(beam.ivar)))
wht = beam.ivar * contam_weight
wht[~np.isfinite(wht)] = 0.0
contam_weight[beam.ivar == 0] = 0
data_i = data[i] * 1.0
scl = 1.0
if convert_to_flambda:
# data_i *= convert_to_flambda/beam.beam.sensitivity
# wht *= (beam.beam.sensitivity/convert_to_flambda)**2
scl = convert_to_flambda # /1.e-17
scl *= 1.0 / beam.flat_flam.reshape(beam.beam.sh_beam).sum(axis=0)
# scl = convert_to_flambda/beam.beam.sensitivity
data_i *= scl
wht *= (1 / scl) ** 2
# contam_weight *= scl
wht[~np.isfinite(data_i + scl)] = 0
contam_weight[~np.isfinite(data_i + scl)] = 0
data_i[~np.isfinite(data_i + scl)] = 0
# Go drizzle
data_wave = np.dot(
np.ones(beam.beam.sh_beam[0])[:, None], beam.beam.lam[None, :]
)
drizzler(
data_wave,
beam_wcs,
wht * 0.0 + 1,
output_wcs,
outls,
outlw,
outlc,
1.0,
"cps",
1,
wcslin_pscale=1.0,
uniqid=1,
pixfrac=1,
kernel="square",
fillval=dfillval,
)
# Direct image
drizzler(
d_sci,
d_beam_wcs,
d_wht,
d_output_wcs,
doutsci,
doutwht,
doutctx,
1.0,
"cps",
1,
wcslin_pscale=d_beam_wcs.pscale,
uniqid=1,
pixfrac=pixfrac,
kernel=kernel,
fillval=dfillval,
)
# Contamination-cleaned
drizzler(
data_i,
beam_wcs,
wht,
output_wcs,
outsci,
outwht,
outctx,
1.0,
"cps",
1,
wcslin_pscale=beam_wcs.pscale,
uniqid=1,
pixfrac=pixfrac,
kernel=kernel,
fillval=dfillval,
)
# For variance
drizzler(
contam_weight,
beam_wcs,
wht,
output_wcs,
outvar,
outwv,
outcv,
1.0,
"cps",
1,
wcslin_pscale=beam_wcs.pscale,
uniqid=1,
pixfrac=pixfrac,
kernel=kernel,
fillval=dfillval,
)
if ds9 is not None:
ds9.view(outsci, header=out_header)
# if True:
# w, f, e = beam.beam.optimal_extract(data_i, ivar=beam.ivar)
# plt.scatter(w, f, marker='.', color='k', alpha=0.5)
# Correct for drizzle scaling
# outsci /= output_wcs.pscale**2
outls /= output_wcs.pscale ** 2
wave = np.median(outls, axis=0)
# # Testing
# fl = (sp[1].data*mask).sum(axis=0)
# variance
outvar /= outwv # *output_wcs.pscale**2
outwht = 1 / outvar
outwht[(outvar == 0) | (~np.isfinite(outwht))] = 0
# return outwht, outsci, outvar, outwv, output_wcs.pscale
p = pyfits.PrimaryHDU()
p.header["ID"] = (beams[0].id, "Object ID")
p.header["WMIN"] = (wave[0], "Minimum wavelength")
p.header["WMAX"] = (wave[-1], "Maximum wavelength")
p.header["DLAM"] = ((wave[-1] - wave[0]) / wave.size, "Delta wavelength")
p.header["FCONTAM"] = (fcontam, "Contamination weight")
p.header["PIXFRAC"] = (pixfrac, "Drizzle PIXFRAC")
p.header["DRIZKRNL"] = (kernel, "Drizzle kernel")
p.header["NINPUT"] = (len(beams), "Number of drizzled beams")
for i, beam in enumerate(beams):
p.header["FILE{0:04d}".format(i + 1)] = (
beam.grism.parent_file,
"Parent filename",
)
p.header["GRIS{0:04d}".format(i + 1)] = (
beam.grism.filter,
"Beam grism element",
)
h = out_header.copy()
for k in p.header:
h[k] = p.header[k]
direct_sci = pyfits.ImageHDU(data=doutsci, header=d_out_header, name="DSCI")
grism_sci = pyfits.ImageHDU(data=outsci, header=h, name="SCI")
grism_wht = pyfits.ImageHDU(data=outwht, header=h, name="WHT")
hdul = pyfits.HDUList([p, grism_sci, grism_wht, direct_sci])
return hdul