"""Functionality for manipulating multiple grism exposures simultaneously
"""
import os
import time
import traceback
import glob
from collections import OrderedDict
import multiprocessing as mp
from . import prep
import scipy.ndimage as nd
import numpy as np
import matplotlib.pyplot as plt
from astropy.table import Table
import astropy.io.fits as pyfits
import astropy.wcs as pywcs
import astropy.units as u
# local imports
from . import utils
from . import model
from . import grismconf
#from . import stack
from .fitting import GroupFitter
from .utils_numba import interp
from .utils import GRISM_COLORS, GRISM_MAJOR, GRISM_LIMITS, DEFAULT_LINE_LIST
def _loadFLT(grism_file, sci_extn, direct_file, pad, ref_file,
ref_ext, seg_file, verbose, catalog, ix, use_jwst_crds):
"""Helper function for loading `.model.GrismFLT` objects with `multiprocessing`.
TBD
"""
import time
try:
import cPickle as pickle
except:
# Python 3
import pickle
# slight random delay to avoid synchronization problems
# np.random.seed(ix)
# sleeptime = ix*1
# print '%s sleep %.3f %d' %(grism_file, sleeptime, ix)
# time.sleep(sleeptime)
# print grism_file, direct_file
new_root = '.{0:02d}.GrismFLT.fits'.format(sci_extn)
save_file = grism_file.replace('_flt.fits', new_root)
save_file = save_file.replace('_flc.fits', new_root)
save_file = save_file.replace('_cmb.fits', new_root)
save_file = save_file.replace('_rate.fits', new_root)
save_file = save_file.replace('_elec.fits', new_root)
if (save_file == grism_file) & ('GrismFLT' not in grism_file):
# couldn't build new filename based on the extensions
# so just insert at the end
save_file = grism_file.replace('.fits', new_root)
if (grism_file.find('_') < 0) & ('GrismFLT' not in grism_file):
save_file = 'xxxxxxxxxxxxxxxxxxx'
if os.path.exists(save_file) & ('GrismFLT' in save_file):
print('Load {0}!'.format(save_file))
fp = open(save_file.replace('GrismFLT.fits', 'GrismFLT.pkl'), 'rb')
flt = pickle.load(fp)
fp.close()
status = flt.load_from_fits(save_file)
else:
flt = model.GrismFLT(grism_file=grism_file, sci_extn=sci_extn,
direct_file=direct_file, pad=pad,
ref_file=ref_file, ref_ext=ref_ext,
seg_file=seg_file, shrink_segimage=True,
verbose=verbose, use_jwst_crds=use_jwst_crds)
if flt.direct.wcs.wcs.has_pc():
for obj in [flt.grism, flt.direct]:
obj.get_wcs()
if catalog is not None:
flt.catalog = flt.blot_catalog(catalog,
sextractor=('X_WORLD' in catalog.colnames))
flt.catalog_file = catalog
else:
flt.catalog = None
if flt.grism.instrument in ['NIRCAM']:
flt.apply_POM()
if flt.grism.instrument in ['NIRISS', 'NIRCAM']:
flt.transform_JWST_WFSS()
if hasattr(flt, 'conf'):
delattr(flt, 'conf')
return flt # , out_cat
def _fit_at_z(self, zgrid, i, templates, fitter, fit_background, poly_order):
"""
For parallel processing
"""
# self, z=0., templates={}, fitter='nnls',
# fit_background=True, poly_order=0
print(i, zgrid[i])
out = self.fit_at_z(z=zgrid[i], templates=templates,
fitter=fitter, poly_order=poly_order,
fit_background=fit_background)
data = {'out': out, 'i': i}
return data
#A, coeffs[i,:], chi2[i], model_2d = out
def _beam_compute_model(beam, id, spectrum_1d, is_cgs, apply_sensitivity, scale, reset):
"""
wrapper function for multiprocessing
"""
beam.beam.compute_model(id=id, spectrum_1d=spectrum_1d,
is_cgs=is_cgs,
scale=scale, reset=reset,
apply_sensitivity=apply_sensitivity)
beam.modelf = beam.beam.modelf
beam.model = beam.beam.modelf.reshape(beam.beam.sh_beam)
return True
# def test_parallel():
#
# zgrid = np.linspace(1.1, 1.3, 10)
# templates = mb.load_templates(fwhm=800)
# fitter = 'nnls'
# fit_background = True
# poly_order = 0
#
# self.FLTs = []
# t0_pool = time.time()
#
# pool = mp.Pool(processes=4)
# results = [pool.apply_async(_fit_at_z, (mb, zgrid, i, templates, fitter, fit_background, poly_order)) for i in range(len(zgrid))]
#
# pool.close()
# pool.join()
#
# chi = zgrid*0.
#
# for res in results:
# data = res.get(timeout=1)
# A, coeffs, chi[data['i']], model_2d = data['out']
# #flt_i.catalog = cat_i
#
# t1_pool = time.time()
def _compute_model(i, flt, fit_info, is_cgs, store, model_kwargs):
"""Helper function for computing model orders.
"""
if not hasattr(flt, 'conf'):
flt.conf = grismconf.load_grism_config(flt.conf_file)
for id in fit_info:
try:
status = flt.compute_model_orders(id=id,
mag=fit_info[id]['mag'], in_place=True, store=store,
spectrum_1d=fit_info[id]['spec'], is_cgs=is_cgs,
verbose=False, **model_kwargs)
except:
print('Failed: {0} {1}'.format(flt.grism.parent_file, id))
continue
print('{0}: _compute_model Done'.format(flt.grism.parent_file))
return i, flt.model, flt.object_dispersers
[docs]class GroupFLT():
def __init__(self, grism_files=[], sci_extn=1, direct_files=[],
pad=(64,256), group_name='group',
ref_file=None, ref_ext=0, seg_file=None,
shrink_segimage=True, verbose=True, cpu_count=0,
catalog='', polyx=[0.3, 5.3],
MW_EBV=0., bits=None, use_jwst_crds=False):
"""Main container for handling multiple grism exposures together
Parameters
----------
grism_files : list
List of grism exposures (typically WFC3/IR "FLT" or ACS/UVIS "FLC"
files). These can be from different grisms and/or orients.
sci_extn : int
Science extension to extract from the files in `grism_files`. For
WFC3/IR this can only be 1, though for the two-chip instruments
WFC3/UVIS and ACS/WFC3 this can be 1 or 2.
direct_files : list
List of direct exposures (typically WFC3/IR "FLT" or ACS/UVIS
"FLC" files). This list should either be empty or should
correspond one-to-one with entries in the `grism_files` list,
i.e., from an undithered pair of direct and grism exposures. If
such pairs weren't obtained or if you simply wish to ignore them
and just use the `ref_file` reference image, set to an empty list
(`[]`).
pad : int, int
Padding in pixels to apply around the edge of the detector to
allow modeling of sources that fall off of the nominal FOV. For
this to work requires using a `ref_file` reference image that
covers this extra area. Specified in array axis order (pady, padx)
group_name : str
Name to apply to products produced by this group.
ref_file : `None` or str
Undistorted reference image filename, e.g., a drizzled mosaic
covering the area around a given grism exposure.
ref_ext : 0
FITS extension of the reference file where to find the image
itself.
seg_file : `None` or str
Segmentation image filename.
shrink_segimage : bool
Do some preprocessing on the segmentation image to speed up the
blotting to the distorted frame of the grism exposures.
verbose : bool
Print verbose information.
cpu_count : int
Use parallelization if > 0. If equal to zero, then use the
maximum number of available cores.
catalog : str
Catalog filename assocated with `seg_file`. These are typically
generated with "SExtractor", but the source of the files
themselves isn't critical.
Attributes
----------
catalog : `~astropy.table.Table`
The table read in with from the above file specified in `catalog`.
FLTs : list
List of `~grizli.model.GrismFLT` objects generated from each of
the files in the `grism_files` list.
grp.N : int
Number of grism files (i.e., `len(FLTs)`.)
"""
N = len(grism_files)
if len(direct_files) != len(grism_files):
direct_files = ['']*N
self.grism_files = grism_files
self.direct_files = direct_files
self.group_name = group_name
# Wavelengths for polynomial fits
self.polyx = polyx
# Read catalog
if catalog:
if isinstance(catalog, str):
self.catalog = utils.GTable.gread(catalog)
else:
self.catalog = catalog
# necessary columns from SExtractor / photutils
pairs = [['NUMBER', 'id'],
['MAG_AUTO', 'mag'],
['MAGERR_AUTO', 'mag_err']]
cols = self.catalog.colnames
for pair in pairs:
if (pair[0] not in cols) & (pair[1] in cols):
self.catalog[pair[0]] = self.catalog[pair[1]]
else:
self.catalog = None
if cpu_count == 0:
cpu_count = mp.cpu_count()
self.FLTs = []
if cpu_count < 0:
# serial
t0_pool = time.time()
for i in range(N):
flt = _loadFLT(self.grism_files[i], sci_extn,
self.direct_files[i], pad, ref_file, ref_ext,
seg_file, verbose, self.catalog, i,
use_jwst_crds)
self.FLTs.append(flt)
t1_pool = time.time()
else:
# Read files in parallel
t0_pool = time.time()
pool = mp.Pool(processes=cpu_count)
results = [pool.apply_async(_loadFLT,
(self.grism_files[i],
sci_extn,
self.direct_files[i],
pad,
ref_file,
ref_ext,
seg_file,
verbose,
self.catalog,
i,
use_jwst_crds)
) for i in range(N)]
pool.close()
pool.join()
for res in results:
flt_i = res.get(timeout=1)
#flt_i.catalog = cat_i
# somehow WCS getting flipped from cd to pc in res.get()???
if flt_i.direct.wcs.wcs.has_pc():
for obj in [flt_i.grism, flt_i.direct]:
obj.get_wcs()
self.FLTs.append(flt_i)
t1_pool = time.time()
# Set conf
for flt_i in self.FLTs:
if not hasattr(flt_i, 'conf'):
flt_i.conf = grismconf.load_grism_config(flt_i.conf_file)
if verbose:
print('Files loaded - {0:.2f} sec.'.format(t1_pool - t0_pool))
@property
def N(self):
return len(self.FLTs)
@property
def Ngrism(self):
"""
dictionary containing number of exposures by grism
"""
# Parse grisms & PAs
Ngrism = {}
for flt in self.FLTs:
if flt.grism.instrument == 'NIRISS':
grism = flt.grism.pupil
else:
grism = flt.grism.filter
if grism not in Ngrism:
Ngrism[grism] = 0
Ngrism[grism] += 1
return Ngrism
@property
def grisms(self):
"""
Available grisms
"""
grisms = list(self.Ngrism.keys())
return grisms
@property
def PA(self):
"""
Available PAs in each grism
"""
_PA = {}
for g in self.Ngrism:
_PA[g] = {}
for i, flt in enumerate(self.FLTs):
if flt.grism.instrument == 'NIRISS':
grism = flt.grism.pupil
else:
grism = flt.grism.filter
PA_i = flt.get_dispersion_PA(decimals=0)
if PA_i not in _PA[grism]:
_PA[grism][PA_i] = []
_PA[grism][PA_i].append(i)
return _PA
[docs] def save_full_data(self, warn=True, verbose=False):
"""Save models and data files for fast regeneration.
The filenames of the outputs are generated from the input grism
exposure filenames with the following:
>>> file = 'ib3701ryq_flt.fits'
>>> sci_extn = 1
>>> new_root = '.{0:02d}.GrismFLT.fits'.format(sci_extn)
>>>
>>> save_file = file.replace('_flt.fits', new_root)
>>> save_file = save_file.replace('_flc.fits', new_root)
>>> save_file = save_file.replace('_cmb.fits', new_root)
>>> save_file = save_file.replace('_rate.fits', new_root)
It will also save data to a `~pickle` file:
>>> pkl_file = save_file.replace('.fits', '.pkl')
Parameters
----------
warn : bool
Print a warning and skip if an output file is already found to
exist.
Notes
-----
The save filename format was changed May 9, 2017 to the format like
`ib3701ryq.01.GrismFLT.fits` from `ib3701ryq_GrismFLT.fits` to both
allow easier filename parsing and also to allow for instruments that
have multiple `SCI` extensions in a single calibrated file
(e.g., ACS and WFC3/UVIS).
"""
for _flt in self.FLTs:
file = _flt.grism_file
if _flt.grism.data is None:
if warn:
print('{0}: Looks like data already saved!'.format(file))
continue
new_root = '.{0:02d}.GrismFLT.fits'.format(_flt.grism.sci_extn)
save_file = file.replace('_flt.fits', new_root)
save_file = save_file.replace('_flc.fits', new_root)
save_file = save_file.replace('_cmb.fits', new_root)
save_file = save_file.replace('_rate.fits', new_root)
save_file = save_file.replace('_elec.fits', new_root)
if (save_file == file) & ('GrismFLT' not in file):
# couldn't build new filename based on the extensions
# so just insert at the end
save_file = file.replace('.fits', new_root)
# Rotate back to detector frame if needed
if _flt.grism.instrument in ['NIRISS', 'NIRCAM']:
_flt.transform_JWST_WFSS(verbose=verbose)
# Save the files
print('Save {0}'.format(save_file))
_flt.save_full_pickle()
# Reload initialized data
_flt.load_from_fits(save_file)
# Rotate to "GrismFLT" frame if needed
if _flt.grism.instrument in ['NIRISS', 'NIRCAM']:
_flt.transform_JWST_WFSS(verbose=verbose)
[docs] def extend(self, new, verbose=True):
"""Add another `GroupFLT` instance to `self`
This function appends the exposures if a separate `GroupFLT` instance
to the current instance. You might do this, for example, if you
generate separate `GroupFLT` instances for different grisms and
reference images with different filters.
"""
import copy
self.FLTs.extend(new.FLTs)
direct_files = copy.copy(self.direct_files)
direct_files.extend(new.direct_files)
self.direct_files = direct_files
grism_files = copy.copy(self.grism_files)
grism_files.extend(new.grism_files)
self.grism_files = grism_files
# self.direct_files.extend(new.direct_files)
# self.grism_files.extend(new.grism_files)
if verbose:
print('Now we have {0:d} FLTs'.format(self.N))
[docs] def compute_single_model(self, id, center_rd=None, mag=-99, size=-1, store=False, spectrum_1d=None, is_cgs=False, get_beams=None, in_place=True, min_size=26, psf_param_dict={}):
"""Compute model spectrum in all exposures
TBD
Parameters
----------
id : type
center_rd : None
mag : type
size : type
store : type
spectrum_1d : type
get_beams : type
in_place : type
Returns
-------
TBD
"""
out_beams = []
for flt in self.FLTs:
if flt.grism.parent_file in psf_param_dict:
psf_params = psf_param_dict[flt.grism.parent_file]
else:
psf_params = None
if center_rd is None:
x = y = None
else:
_rd = np.array(center_rd)[None, :]
x, y = flt.direct.wcs.all_world2pix(_rd, 0).flatten()
status = flt.compute_model_orders(id=id, x=x, y=y, verbose=False,
size=size, compute_size=(size < 0),
mag=mag, in_place=in_place, store=store,
spectrum_1d=spectrum_1d, is_cgs=is_cgs,
get_beams=get_beams, psf_params=psf_params,
min_size=min_size)
out_beams.append(status)
if get_beams:
return out_beams
else:
return True
[docs] def compute_full_model(self, fit_info=None, verbose=True, store=False,
mag_limit=25, coeffs=[1.2, -0.5], cpu_count=0,
is_cgs=False, model_kwargs={'compute_size':True}):
"""Compute continuum models of all sources in an FLT
Parameters
----------
fit_info : dict
verbose : bool
store : bool
mag_limit : float
Faint limit of objects to compute
coeffs : list
Polynomial coefficients of the continuum model
cpu_count : int
Number of CPUs to use for parallel processing. If 0, then get
from `multiprocessing.cpu_count`.
is_cgs : bool
Spectral models are in cgs units
model_kwargs : dict
Keywords to pass to the
`~grizli.model.GrismFLT.compute_model_orders` method of the
`~grizli.model.GrismFLT` objects.
Returns
-------
Sets `object_dispersers` and `model` attributes on items in
`self.FLTs`
"""
if cpu_count <= 0:
cpu_count = np.maximum(mp.cpu_count() - 4, 1)
if fit_info is None:
bright = self.catalog['MAG_AUTO'] < mag_limit
ids = self.catalog['NUMBER'][bright]
mags = self.catalog['MAG_AUTO'][bright]
# Polynomial component
#xspec = np.arange(0.3, 5.35, 0.05)-1
xspec = np.arange(self.polyx[0], self.polyx[1], 0.05)
if len(self.polyx) > 2:
px0 = self.polyx[2]
else:
px0 = 1.0
yspec = [(xspec-px0)**o*coeffs[o] for o in range(len(coeffs))]
xspec = (xspec)*1.e4
yspec = np.sum(yspec, axis=0)
fit_info = OrderedDict()
for id, mag in zip(ids, mags):
fit_info[id] = {'mag': mag, 'spec': [xspec, yspec]}
is_cgs = False
t0_pool = time.time()
# Remove conf
for flt_i in self.FLTs:
if hasattr(flt_i, 'conf'):
delattr(flt_i, 'conf')
pool = mp.Pool(processes=cpu_count)
jobs = [pool.apply_async(_compute_model,
(i, self.FLTs[i], fit_info,
is_cgs, store, model_kwargs))
for i in range(self.N)]
pool.close()
pool.join()
for res in jobs:
i, model, dispersers = res.get(timeout=1)
self.FLTs[i].object_dispersers = dispersers
self.FLTs[i].model = model
# Reload conf
for flt_i in self.FLTs:
if not hasattr(flt_i, 'conf'):
flt_i.conf = grismconf.load_grism_config(flt_i.conf_file)
t1_pool = time.time()
if verbose:
print('Models computed - {0:.2f} sec.'.format(t1_pool - t0_pool))
[docs] def get_beams(self, id, size=10, center_rd=None, beam_id='A',
min_overlap=0.1, min_valid_pix=10, min_mask=0.01,
min_sens=0.08, mask_resid=True, get_slice_header=True,
show_exception=False):
"""Extract 2D spectra "beams" from the GroupFLT exposures.
Parameters
----------
id : int
Catalog ID of the object to extract.
size : int
Half-size of the 2D spectrum to extract, along cross-dispersion
axis.
center_rd : optional, (float, float)
Extract based on RA/Dec rather than catalog ID.
beam_id : type
Name of the order to extract.
min_overlap : float
Fraction of the spectrum along wavelength axis that has one
or more valid pixels.
min_valid_pix : int
Minimum number of valid pixels (`beam.fit_mask == True`) in 2D
spectrum.
min_mask : float
Minimum factor relative to the maximum pixel value of the flat
f-lambda model where the 2D cutout data are considered good.
Passed through to `~grizli.model.BeamCutout`.
min_sens : float
See `~grizli.model.BeamCutout`.
get_slice_header : bool
Passed to `~grizli.model.BeamCutout`.
Returns
-------
beams : list
List of `~grizli.model.BeamCutout` objects.
"""
beams = self.compute_single_model(id,
center_rd=center_rd,
size=size,
store=False,
get_beams=[beam_id])
out_beams = []
for flt, beam in zip(self.FLTs, beams):
try:
out_beam = model.BeamCutout(flt=flt, beam=beam[beam_id],
conf=flt.conf, min_mask=min_mask,
min_sens=min_sens,
mask_resid=mask_resid,
get_slice_header=get_slice_header)
except:
#print('Except: get_beams')
if show_exception:
utils.log_exception(utils.LOGFILE, traceback)
continue
valid = (out_beam.grism['SCI'] != 0)
valid &= out_beam.fit_mask.reshape(out_beam.sh)
hasdata = (valid.sum(axis=0) > 0).sum()
if hasdata*1./out_beam.model.shape[1] < min_overlap:
continue
# Empty direct image?
if out_beam.beam.total_flux == 0:
continue
if out_beam.fit_mask.sum() < min_valid_pix:
continue
out_beams.append(out_beam)
return out_beams
[docs] def refine_list(self, ids=[], mags=[], poly_order=3, mag_limits=[16, 24],
max_coeff=5, ds9=None, verbose=True, fcontam=0.5,
wave=np.linspace(0.2, 2.5e4, 100)):
"""Refine contamination model for list of objects. Loops over `refine`.
Parameters
----------
ids : list
List of object IDs
mags : list
Magnitudes to to along with IDs. If `ids` and `mags` not
specified, then get the ID list from `self.catalog['MAG_AUTO']`.
poly_order : int
Order of the polynomial fit to the spectra.
mag_limits : [float, float]
Magnitude limits of objects to fit from `self.catalog['MAG_AUTO']`
when `ids` and `mags` not set.
max_coeff : float
Fit is considered bad when one of the coefficients is greater
than this value. See `refine`.
ds9 : `~grizli.ds9.DS9`, optional
Display the refined models to DS9 as they are computed.
verbose : bool
Print fit coefficients.
fcontam : float
Contamination weighting parameter.
wave : `~numpy.array`
Wavelength array for the polynomial fit.
Returns
-------
Updates `self.model` in place.
"""
if (len(ids) == 0) | (len(ids) != len(mags)):
bright = ((self.catalog['MAG_AUTO'] < mag_limits[1]) &
(self.catalog['MAG_AUTO'] > mag_limits[0]))
ids = self.catalog['NUMBER'][bright]*1
mags = self.catalog['MAG_AUTO'][bright]*1
so = np.argsort(mags)
ids, mags = ids[so], mags[so]
#wave = np.linspace(0.2,5.4e4,100)
poly_templates = utils.polynomial_templates(wave, order=poly_order, line=False)
for id, mag in zip(ids, mags):
self.refine(id, mag=mag, poly_order=poly_order,
max_coeff=max_coeff, size=30, ds9=ds9,
verbose=verbose, fcontam=fcontam,
templates=poly_templates)
[docs] def refine(self, id, mag=-99, poly_order=3, size=30, ds9=None, verbose=True, max_coeff=2.5, fcontam=0.5, templates=None):
"""Fit polynomial to extracted spectrum of single object to use for contamination model.
Parameters
----------
id : int
Object ID to extract.
mag : float
Object magnitude. Determines which orders to extract; see
`~grizli.model.GrismFLT.compute_model_orders`.
poly_order : int
Order of the polynomial to fit.
size : int
Size of cutout to extract.
ds9 : `~grizli.ds9.DS9`, optional
Display the refined models to DS9 as they are computed.
verbose : bool
Print information about the fit
max_coeff : float
The script computes the implied flux of the polynomial template
at the pivot wavelength of the direct image filters. If this
flux is greater than `max_coeff` times the *observed* flux in the
direct image, then the polynomal fit is considered bad.
fcontam : float
Contamination weighting parameter.
templates : dict, optional
Precomputed template dictionary. If `None` then compute
polynomial templates with order `poly_order`.
Returns
-------
Updates `self.model` in place.
"""
beams = self.get_beams(id, size=size, min_overlap=0.1,
get_slice_header=False, min_mask=0.01,
min_sens=0.01, mask_resid=True)
if len(beams) == 0:
return True
mb = MultiBeam(beams, fcontam=fcontam, min_sens=0.01, sys_err=0.03,
min_mask=0.01, mask_resid=True)
if templates is None:
wave = np.linspace(0.9*mb.wavef.min(), 1.1*mb.wavef.max(), 100)
templates = utils.polynomial_templates(wave, order=poly_order,
line=False)
try:
tfit = mb.template_at_z(z=0, templates=templates, fit_background=True, fitter='lstsq', get_uncertainties=2)
except:
ret = False
return False
scale_coeffs = [tfit['cfit']['poly {0}'.format(i)][0] for i in range(1+poly_order)]
xspec, ypoly = tfit['cont1d'].wave, tfit['cont1d'].flux
# Don't extrapolate
mb_waves = mb.wavef[mb.fit_mask]
mb_clip = (xspec > mb_waves.min()) & (xspec < mb_waves.max())
if mb_clip.sum() > 0:
ypoly[xspec < mb_waves.min()] = ypoly[mb_clip][0]
ypoly[xspec > mb_waves.max()] = ypoly[mb_clip][-1]
# Check where templates inconsistent with broad-band fluxes
xb = [beam.direct.ref_photplam if beam.direct['REF'] is not None else beam.direct.photplam for beam in beams]
obs_flux = np.array([beam.beam.total_flux for beam in beams])
mod_flux = np.polynomial.Polynomial(scale_coeffs)(np.array(xb)/1.e4-1)
nonz = obs_flux != 0
if (np.abs(mod_flux/obs_flux)[nonz].max() > max_coeff) | ((~np.isfinite(mod_flux/obs_flux)[nonz]).sum() > 0) | (np.min(mod_flux[nonz]) < 0) | ((~np.isfinite(ypoly)).sum() > 0):
if verbose:
cstr = ' '.join(['{0:9.2e}'.format(c) for c in scale_coeffs])
print('{0:>5d} mag={1:6.2f} {2} xx'.format(id, mag, cstr))
return True
# Put the refined model into the full-field model
self.compute_single_model(id, mag=mag, size=-1, store=False, spectrum_1d=[xspec, ypoly], is_cgs=True, get_beams=None, in_place=True)
# Display the result?
if ds9:
flt = self.FLTs[0]
mask = flt.grism['SCI'] != 0
ds9.view((flt.grism['SCI'] - flt.model)*mask,
header=flt.grism.header)
if verbose:
cstr = ' '.join(['{0:9.2e}'.format(c) for c in scale_coeffs])
print('{0:>5d} mag={1:6.2f} {2}'.format(id, mag, cstr))
return True
#m2d = mb.reshape_flat(modelf)
# ############
# def old_refine(self, id, mag=-99, poly_order=1, size=30, ds9=None, verbose=True, max_coeff=2.5):
# """TBD
# """
# # Extract and fit beam spectra
# beams = self.get_beams(id, size=size, min_overlap=0.5, get_slice_header=False)
# if len(beams) == 0:
# return True
#
# mb = MultiBeam(beams)
# try:
# A, out_coeffs, chi2, modelf = mb.fit_at_z(poly_order=poly_order, fit_background=True, fitter='lstsq')
# except:
# return False
#
# # Poly template
# scale_coeffs = out_coeffs[mb.N*mb.fit_bg:mb.N*mb.fit_bg+mb.n_poly]
# xspec, yfull = mb.eval_poly_spec(out_coeffs)
#
# # Check where templates inconsistent with broad-band fluxes
# xb = [beam.direct.ref_photplam if beam.direct['REF'] is not None
# else beam.direct.photplam for beam in beams]
# fb = [beam.beam.total_flux for beam in beams]
# mb = np.polynomial.Polynomial(scale_coeffs)(np.array(xb)/1.e4-1)
#
# if (np.abs(mb/fb).max() > max_coeff) | (~np.isfinite(mb/fb).sum() > 0) | (np.min(mb) < 0):
# if verbose:
# print('{0} mag={1:6.2f} {2} xx'.format(id, mag, scale_coeffs))
#
# return True
#
# # Put the refined model into the full-field model
# self.compute_single_model(id, mag=mag, size=-1, store=False, spectrum_1d=[(xspec+1)*1.e4, yfull], is_cgs=True, get_beams=None, in_place=True)
#
# # Display the result?
# if ds9:
# flt = self.FLTs[0]
# mask = flt.grism['SCI'] != 0
# ds9.view((flt.grism['SCI'] - flt.model)*mask,
# header=flt.grism.header)
#
# if verbose:
# print('{0} mag={1:6.2f} {2}'.format(id, mag, scale_coeffs))
#
# return True
# #m2d = mb.reshape_flat(modelf)
[docs] def make_stack(self, id, size=20, target='grism', skip=True, fcontam=1., scale=1, save=True, kernel='point', pixfrac=1, diff=True):
"""Make drizzled 2D stack for a given object
Parameters
----------
id : int
Object ID number.
target : str
Rootname for output files.
skip : bool
If True and the stack PNG file already exists, don't proceed.
fcontam : float
Contamination weighting parameter.
save : bool
Save the figure and FITS HDU to files with names like
>>> img_file = '{0}_{1:05d}.stack.png'.format(target, id)
>>> fits_file = '{0}_{1:05d}.stack.fits'.format(target, id)
diff : bool
Plot residual in final stack panel.
Returns
-------
hdu : `~astropy.io.fits.HDUList`
FITS HDU of the stacked spectra.
fig : `~matplotlib.figure.Figure`
Stack figure object.
"""
print(target, id)
if os.path.exists('{0}_{1:05d}.stack.png'.format(target, id)) & skip:
return True
beams = self.get_beams(id, size=size, beam_id='A')
if len(beams) == 0:
print('id = {0}: No beam cutouts available.'.format(id))
return None
mb = MultiBeam(beams, fcontam=fcontam, group_name=target)
hdu, fig = mb.drizzle_grisms_and_PAs(fcontam=fcontam, flambda=False,
size=size, scale=scale,
kernel=kernel, pixfrac=pixfrac,
diff=diff)
if save:
fig.savefig('{0}_{1:05d}.stack.png'.format(target, id))
hdu.writeto('{0}_{1:05d}.stack.fits'.format(target, id),
overwrite=True)
return hdu, fig
[docs] def drizzle_grism_models(self, root='grism_model', kernel='square', scale=0.1, pixfrac=1, make_figure=True, fig_xsize=10, write_ctx=False):
"""
Make model-subtracted drizzled images of each grism / PA
Parameters
----------
root : str
Rootname of the output files.
kernel : str
Drizzle kernel e.g., ('square', 'point').
scale : float
Drizzle `scale` parameter, pixel scale in arcsec.
pixfrac : float
Drizzle "pixfrac".
write_ctx : bool
Save context image as well.
"""
try:
from .utils import drizzle_array_groups
except:
from grizli.utils import drizzle_array_groups
# Loop through grisms and PAs
for g in self.PA:
for pa in self.PA[g]:
idx = self.PA[g][pa]
N = len(idx)
sci_list = []
for i in idx:
grism = self.FLTs[i].grism
if 'MED' in grism.data:
sci_list.append(grism['SCI'] + grism['MED'])
else:
sci_list.append(grism['SCI'])
#sci_list = [self.FLTs[i].grism['SCI'] for i in idx]
clean_list = [self.FLTs[i].grism['SCI']-self.FLTs[i].model
for i in idx]
wht_list = [
(self.FLTs[i].grism['DQ'] == 0) / self.FLTs[i].grism['ERR']**2
for i in idx
]
for i in range(N):
mask = ~np.isfinite(wht_list[i])
wht_list[i][mask] = 0
wcs_list = [self.FLTs[i].grism.wcs for i in idx]
for i, ix in enumerate(idx):
if wcs_list[i]._naxis[0] == 0:
wcs_list[i]._naxis = self.FLTs[ix].grism.sh[::-1]
# Science array
outfile = '{0}-{1}-{2}_grism_sci.fits'.format(root, g.lower(),
pa)
print(outfile)
out = drizzle_array_groups(
sci_list,
wht_list,
wcs_list,
scale=scale,
kernel=kernel,
pixfrac=pixfrac
)
outsci, _, _, outctx, header, outputwcs = out
header['FILTER'] = g
header['PA'] = pa
# Add invidual FLTs to header
for i, ix in enumerate(idx):
header[f'FLT{str(i+1).zfill(5)}'] = self.FLTs[ix].grism_file
pyfits.writeto(outfile, data=outsci, header=header,
overwrite=True, output_verify='fix')
if write_ctx:
pyfits.writeto(outfile.replace('sci', 'ctx'), data=outctx,
header=header, overwrite=True, output_verify='fix')
# Model-subtracted
outfile = '{0}-{1}-{2}_grism_clean.fits'.format(root, g.lower(),
pa)
print(outfile)
out = drizzle_array_groups(
clean_list,
wht_list,
wcs_list,
scale=scale,
kernel=kernel,
pixfrac=pixfrac
)
outsci, _, _, _, header, outputwcs = out
header['FILTER'] = g
header['PA'] = pa
pyfits.writeto(outfile, data=outsci, header=header,
overwrite=True, output_verify='fix')
# Make figure
if make_figure:
with pyfits.open(outfile.replace('clean', 'sci')) as img:
im = img[0].data*1
im[im == 0] = np.nan
sh = im.shape
yp, xp = np.indices(sh)
mask = np.isfinite(im)
xmi = np.maximum(xp[mask].min()-10, 0)
xma = np.minimum(xp[mask].max()+10, sh[1])
ymi = np.maximum(yp[mask].min()-10, 0)
yma = np.minimum(yp[mask].max()+10, sh[0])
xsl = slice(xmi, xma)
ysl = slice(ymi, yma)
_dy = (ysl.stop - ysl.start)
_dx = (xsl.stop - xsl.start)
sh_aspect = _dy / _dx
vmi, vma = -0.05, 0.2
fig = plt.figure(figsize=[fig_xsize,
fig_xsize/2*sh_aspect])
ax = fig.add_subplot(121)
ax.imshow(im[ysl, xsl], origin='lower', cmap='gray_r',
vmin=vmi, vmax=vma)
# Clean
ax = fig.add_subplot(122)
with pyfits.open(outfile) as img:
im = img[0].data*1
im[im == 0] = np.nan
ax.imshow(im[ysl, xsl], origin='lower', cmap='gray_r',
vmin=vmi, vmax=vma)
for ax in fig.axes:
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.axis('off')
fig.tight_layout(pad=0.)
fig.text(0.5, 0.98, outfile.split('_grism')[0], color='k',
bbox=dict(facecolor='w', edgecolor='None'),
ha='center', va='top', transform=fig.transFigure)
fig.savefig(outfile.split('_clean')[0]+'.png',
transparent=True)
plt.close(fig)
[docs] def subtract_sep_background(self, mask_threshold=3, mask_iter=3, bw=256, bh=256, fw=3, fh=3, revert=True, **kwargs):
"""
Remove a 2D background from grism exposures with `sep.Background`
Parameters
----------
mask_threshold : float
S/N threshold to mask sources
"""
import sep
for flt in self.FLTs:
msg = f'subtract_sep_background: {flt.grism.parent_file} '
msg += f' mask_threshold={mask_threshold}'
msg += f' fw,fh={fw},{fh} bw,bh={bw},{bh}'
utils.log_comment(utils.LOGFILE, msg, verbose=True)
sci_i = (flt.grism.data['SCI'] - flt.model).astype(np.float32)*1
if revert & ('BKG' in flt.grism.data):
sci_i += flt.grism.data['BKG']
err_i = flt.grism.data['ERR']
dq_i = flt.grism.data['DQ']
ok = (sci_i != 0) & (err_i > 0) & np.isfinite(err_i+sci_i)
ok &= ((dq_i & 1025) == 0)
sci_i[~ok] = 0
ierr = 1/err_i
ierr[~ok] = 0
bkg = 0
for _iter_i in range(mask_iter):
mask = ~(ok & ((sci_i - bkg)*ierr < mask_threshold))
back = sep.Background(sci_i, mask=mask,
bw=bw, bh=bh, fw=fw, fh=fh)
bkg = back.back()
flt.grism.data['SCI'] -= bkg*ok
flt.grism.data['BKG'] = bkg*ok
flt.grism.header['BKGFW'] = fw, 'sep background fw'
flt.grism.header['BKGFH'] = fh, 'sep background fh'
flt.grism.header['BKGBW'] = bw, 'sep background bw'
flt.grism.header['BKGBH'] = bh, 'sep background bh'
[docs] def drizzle_full_wavelength(self, wave=1.4e4, ref_header=None,
kernel='point', pixfrac=1., verbose=True,
offset=[0, 0], fcontam=0.):
"""Drizzle FLT frames recentered at a specified wavelength
Script computes polynomial coefficients that define the dx and dy
offsets to a specific dispersed wavelengh relative to the reference
position and adds these to the SIP distortion keywords before
drizzling the input exposures to the output frame.
Parameters
----------
wave : float
Reference wavelength to center the output products
ref_header : `~astropy.io.fits.Header`
Reference header for setting the output WCS and image dimensions.
kernel : str, ('square' or 'point')
Drizzle kernel to use
pixfrac : float
Drizzle PIXFRAC (for `kernel` = 'point')
verbose : bool
Print information to terminal
Returns
-------
sci, wht : `~np.ndarray`
Drizzle science and weight arrays with dimensions set in
`ref_header`.
"""
from astropy.modeling import models, fitting
import astropy.wcs as pywcs
# try:
# import drizzle
# if drizzle.__version__ != '1.12.99':
# # Not the fork that works for all input/output arrays
# raise(ImportError)
#
# #print('drizzle!!')
# from drizzle.dodrizzle import dodrizzle
# drizzler = dodrizzle
# dfillval = '0'
# except:
from drizzlepac import adrizzle
adrizzle.log.setLevel('ERROR')
drizzler = adrizzle.do_driz
dfillval = 0
# Quick check now for which grism exposures we should use
if wave < 1.1e4:
use_grism = 'G102'
else:
use_grism = 'G141'
# Get the configuration file
conf = None
for i in range(self.N):
if self.FLTs[i].grism.filter == use_grism:
conf = self.FLTs[i].conf
# Grism not found in list
if conf is None:
return False
# Compute field-dependent dispersion parameters
dydx_0_p = conf.conf['DYDX_A_0']
dydx_1_p = conf.conf['DYDX_A_1']
dldp_0_p = conf.conf['DLDP_A_0']
dldp_1_p = conf.conf['DLDP_A_1']
yp, xp = np.indices((1014, 1014)) # hardcoded for WFC3/IR
sk = 10 # don't need to evaluate at every pixel
dydx_0 = conf.field_dependent(xp[::sk, ::sk], yp[::sk, ::sk], dydx_0_p)
dydx_1 = conf.field_dependent(xp[::sk, ::sk], yp[::sk, ::sk], dydx_1_p)
dldp_0 = conf.field_dependent(xp[::sk, ::sk], yp[::sk, ::sk], dldp_0_p)
dldp_1 = conf.field_dependent(xp[::sk, ::sk], yp[::sk, ::sk], dldp_1_p)
# Inverse pixel offsets from the specified wavelength
dp = (wave - dldp_0)/dldp_1
i_x, i_y = 1, 0 # indexing offsets
dx = dp/np.sqrt(1+dydx_1) + i_x
dy = dydx_0 + dydx_1*dx + i_y
dx += offset[0]
dy += offset[1]
# Compute polynomial coefficients
p_init = models.Polynomial2D(degree=4)
#fit_p = fitting.LevMarLSQFitter()
fit_p = fitting.LinearLSQFitter()
p_dx = fit_p(p_init, xp[::sk, ::sk]-507, yp[::sk, ::sk]-507, -dx)
p_dy = fit_p(p_init, xp[::sk, ::sk]-507, yp[::sk, ::sk]-507, -dy)
# Output WCS
out_wcs = pywcs.WCS(ref_header, relax=True)
out_wcs.pscale = utils.get_wcs_pscale(out_wcs)
# Initialize outputs
shape = (ref_header['NAXIS2'], ref_header['NAXIS1'])
outsci = np.zeros(shape, dtype=np.float32)
outwht = np.zeros(shape, dtype=np.float32)
outctx = np.zeros(shape, dtype=np.int32)
# Loop through exposures
for i in range(self.N):
flt = self.FLTs[i]
if flt.grism.filter != use_grism:
continue
h = flt.grism.header.copy()
# Update SIP coefficients
for j, p in enumerate(p_dx.param_names):
key = 'A_'+p[1:]
if key in h:
h[key] += p_dx.parameters[j]
else:
h[key] = p_dx.parameters[j]
for j, p in enumerate(p_dy.param_names):
key = 'B_'+p[1:]
if key in h:
h[key] += p_dy.parameters[j]
else:
h[key] = p_dy.parameters[j]
line_wcs = pywcs.WCS(h, relax=True)
line_wcs.pscale = utils.get_wcs_pscale(line_wcs)
if not hasattr(line_wcs, 'pixel_shape'):
line_wcs.pixel_shape = line_wcs._naxis1, line_wcs._naxis2
# Science and wht arrays
sci = flt.grism['SCI'] - flt.model
wht = 1/(flt.grism['ERR']**2)
scl = np.exp(-(fcontam*np.abs(flt.model)/flt.grism['ERR']))
wht *= scl
wht[~np.isfinite(wht)] = 0
# Drizzle it
if verbose:
print('Drizzle {0} to wavelength {1:.2f}'.format(flt.grism.parent_file, wave))
drizzler(sci, line_wcs, wht, out_wcs,
outsci, outwht, outctx, 1., 'cps', 1,
wcslin_pscale=line_wcs.pscale, uniqid=1,
pixfrac=pixfrac, kernel=kernel, fillval=dfillval)
# Done!
return outsci, outwht
[docs] def find_source_along_trace(self, line_ra, line_dec):
"""
Given a sky position in the (artificial) dispersed image frame,
find sources that would disperse to that position
"""
# Find an exposure that contains the position
conf = None
for flt in self.FLTs:
wcs_i = flt.grism.wcs
xw, yw = np.squeeze(wcs_i.all_world2pix([line_ra], [line_dec], 0))
sh = flt.grism.sh
if (xw > 0) & (yw > 0) & (xw < sh[1]) & (yw < sh[0]):
conf = flt.conf
break
if conf is None:
msg = f'No exposures disperse to {line_ra}, {line_dec}'
print(msg)
return None
# TBD
return None
# def replace_direct_image_cutouts(beams_file='', ref_image='gdn-100mas-f160w_drz_sci.fits', interp='poly5', cutout=200, background_func=utils.mode_statistic):
# """
# Replace "REF" extensions in a `beams.fits` file
#
# Parameters
# ----------
# beams_file : str
# Filename of a "beams.fits" file.
#
# ref_image : str or `~astropy.io.fits.HDUList`
# Filename or preloaded FITS file.
#
# interp : str
# Interpolation function to use for `~drizzlepac.astrodrizzle.ablot.do_blot`.
#
# cutout : int
# Make a slice of the `ref_image` with size [-cutout,+cutout] around
# the center position of the desired object before passing to `blot`.
#
# Returns
# -------
# beams_image : `~astropy.io.fits.HDUList`
# Image object with the "REF" extensions filled with the new blotted
# image cutouts.
#
# """
# from drizzlepac.astrodrizzle import ablot
#
# if isinstance(ref_image, pyfits.HDUList):
# ref_im = ref_image
# ref_image_filename = ref_image.filename()
# else:
# ref_im = pyfits.open(ref_image)
# ref_image_filename = ref_image
#
# ref_wcs = pywcs.WCS(ref_im[0].header, relax=True)
# ref_wcs.pscale = utils.get_wcs_pscale(ref_wcs)
#
# ref_photflam = ref_im[0].header['PHOTFLAM']
# ref_data = ref_im[0].data
# dummy_wht = np.ones_like(ref_im[0].data, dtype=np.float32)
#
# beams_image = pyfits.open(beams_file)
#
# beam_ra = beams_image[0].header['RA']
# beam_dec = beams_image[0].header['DEC']
#
# xy = np.asarray(np.round(ref_wcs.all_world2pix([beam_ra], [beam_dec], 0)),dtype=int).flatten()
#
# slx = slice(xy[0]-cutout, xy[0]+cutout)
# sly = slice(xy[1]-cutout, xy[1]+cutout)
#
# bkg_data = []
#
# for ie, ext in enumerate(beams_image):
# if 'EXTNAME' not in ext.header:
# continue
# elif ext.header['EXTNAME'] == 'REF':
# #break
#
# ext.header['REF_FILE'] = ref_image_filename
# for k in ['PHOTFLAM', 'PHOTPLAM']:
# ext.header[k] = ref_im[0].header[k]
#
# the_filter = utils.parse_filter_from_header(ref_im[0].header)
# ext.header['FILTER'] = ext.header['DFILTER'] = the_filter
#
# wcs_file = ext.header['GPARENT'].replace('.fits', '.{0:02}.wcs.fits'.format(ext.header['SCI_EXTN']))
# if os.path.exists(wcs_file):
# wcs_fobj = pyfits.open(wcs_file)
#
# ext_wcs = pywcs.WCS(ext.header, relax=True,
# fobj=wcs_fobj)
# # ext_wcs.pixel_shape = (wcs_fobj[0].header['CRPIX1']*2,
# # wcs_fobj[0].header['CRPIX2']*2)
# # try:
# # ext_wcs.wcs.cd = ext_wcs.wcs.pc
# # delattr(ext_wcs.wcs, 'pc')
# # except:
# # pass
# else:
# ext_wcs = pywcs.WCS(ext.header, relax=True)
#
# ext_wcs.pscale = utils.get_wcs_pscale(ext_wcs)
# blotted = ablot.do_blot(ref_data[sly, slx],
# ref_wcs.slice([sly, slx]),
# ext_wcs, 1, coeffs=True, interp=interp,
# sinscl=1.0, stepsize=10, wcsmap=None)
#
# if background_func is not None:
# seg_data = beams_image[ie+1].data
# msk = seg_data == 0
# #print(msk.shape, blotted.shape, seg_data.shape, ie)
# if msk.sum() > 0:
# if bkg_data is None:
# bkg_data = blotted[msk]
# else:
# bkg_data = np.append(bkg_data, blotted[msk])
#
# if msk.sum() > 0:
# blotted -= background_func(blotted[msk])
#
# ext.data = blotted*ref_photflam
#
# if bkg_data is not None:
# bkg_value = background_func(bkg_data)
# for i in range(self.N):
#
# return beams_image
[docs]class MultiBeam(GroupFitter):
def __init__(self, beams, group_name=None, fcontam=0., psf=False, polyx=[0.3, 2.5], MW_EBV=0., min_mask=0.01, min_sens=0.08, sys_err=0.0, mask_resid=True, verbose=True, replace_direct=None, restore_medfilt=False, **kwargs):
"""Tools for dealing with multiple `~.model.BeamCutout` instances
Parameters
----------
beams : list
List of `~.model.BeamCutout` objects.
group_name : str, None
Rootname to use for saved products. If None, then default to
'group'.
fcontam : float
Factor to use to downweight contaminated pixels. The pixel
inverse variances are scaled by the following weight factor when
evaluating chi-squared of a 2D fit,
`weight = np.exp(-(fcontam*np.abs(contam)*np.sqrt(ivar)))`
where `contam` is the contaminating flux and `ivar` is the initial
pixel inverse variance.
psf : bool
Fit an ePSF model to the direct image to use as the morphological
reference.
MW_EBV : float
Milky way foreground extinction.
min_mask : float
Minimum factor relative to the maximum pixel value of the flat
f-lambda model where the 2D cutout data are considered good.
Passed through to `~grizli.model.BeamCutout`.
min_sens : float
See `~grizli.model.BeamCutout`.
sys_err : float
Systematic error added in quadrature to the pixel variances:
`var_total = var_initial + (beam.sci*sys_err)**2`
Attributes
----------
TBD : type
"""
if group_name is None:
self.group_name = 'group'
else:
self.group_name = group_name
self.fcontam = fcontam
self.polyx = polyx
self.min_mask = min_mask
self.min_sens = min_sens
self.mask_resid = mask_resid
self.sys_err = sys_err
self.restore_medfilt = restore_medfilt
self.Asave = {}
if isinstance(beams, str):
self.load_master_fits(beams, verbose=verbose)
# check if isJWST
isJWST = prep.check_isJWST(beams)
# Auto-generate group_name from filename, e.g.,
# j100140p0130_00237.beams.fits > j100140p0130
if group_name is None:
self.group_name = beams.split('_')[0]
else:
if isinstance(beams[0], str):
# `beams` is list of strings
if 'beams.fits' in beams[0]:
# Master beam files
# check if isJWST
isJWST = prep.check_isJWST(beams[0])
self.load_master_fits(beams[0], verbose=verbose)
for i in range(1, len(beams)):
b_i = MultiBeam(beams[i],
group_name=group_name,
fcontam=fcontam,
psf=psf, polyx=polyx,
MW_EBV=np.maximum(MW_EBV, 0),
sys_err=sys_err,
verbose=verbose,
min_mask=min_mask,
min_sens=min_sens,
mask_resid=mask_resid)
self.extend(b_i)
else:
# List of individual beam.fits files
self.load_beam_fits(beams)
else:
self.beams = beams
self.ra, self.dec = self.beams[0].get_sky_coords()
if MW_EBV < 0:
# Try to get MW_EBV from mastquery.utils
try:
import mastquery.utils
MW_EBV = mastquery.utils.get_mw_dust(self.ra, self.dec)
except:
try:
import mastquery.utils
MW_EBV = mastquery.utils.get_irsa_dust(self.ra, self.dec)
except:
MW_EBV = 0.
self.MW_EBV = MW_EBV
self._set_MW_EBV(MW_EBV)
self._parse_beams(psf=psf)
self.apply_trace_shift()
self.Nphot = 0
self.is_spec = 1
if replace_direct is not None:
self.replace_direct_image_cutouts(**replace_direct)
def _set_MW_EBV(self, MW_EBV, R_V=utils.MW_RV):
"""
Initialize Galactic extinction
Parameters
----------
MW_EBV : float
Local E(B-V)
R_V : float
Relation between specific and total extinction,
``a_v = r_v * ebv``.
"""
for b in self.beams:
beam = b.beam
if beam.MW_EBV != MW_EBV:
beam.MW_EBV = MW_EBV
beam.init_galactic_extinction(MW_EBV, R_V=R_V)
beam.process_config()
b.flat_flam = b.compute_model(in_place=False, is_cgs=True)
@property
def N(self):
return len(self.beams)
@property
def Ngrism(self):
"""
dictionary containing number of exposures by grism
"""
# Parse grisms & PAs
Ngrism = {}
for beam in self.beams:
if beam.grism.instrument == 'NIRISS':
grism = beam.grism.pupil
else:
grism = beam.grism.filter
if grism not in Ngrism:
Ngrism[grism] = 0
Ngrism[grism] += 1
return Ngrism
@property
def grisms(self):
"""
Available grisms
"""
grisms = list(self.Ngrism.keys())
return grisms
@property
def PA(self):
"""
Available PAs in each grism
"""
_PA = {}
for g in self.Ngrism:
_PA[g] = {}
for i, beam in enumerate(self.beams):
if beam.grism.instrument == 'NIRISS':
grism = beam.grism.pupil
else:
grism = beam.grism.filter
PA_i = beam.get_dispersion_PA(decimals=0)
if PA_i in _PA[grism]:
_PA[grism][PA_i].append(i)
else:
_PA[grism][PA_i] = [i]
return _PA
@property
def id(self):
return self.beams[0].id
def _parse_beams(self, psf=False):
"""
Derive properties of the beam list (grism, PA) and initialize
data arrays.
"""
# Use WFC3 ePSF for the fit
self.psf_param_dict = None
if (psf > 0) & (self.beams[0].grism.instrument in ['WFC3', 'ACS']):
self.psf_param_dict = OrderedDict()
for ib, beam in enumerate(self.beams):
if (beam.direct.data['REF'] is not None):
# Use REF extension. scale factors might be wrong
beam.direct.data['SCI'] = beam.direct.data['REF']
new_err = np.ones_like(beam.direct.data['ERR'])
new_err *= utils.nmad(beam.direct.data['SCI'])
beam.direct.data['ERR'] = new_err
beam.direct.filter = beam.direct.ref_filter # 'F160W'
beam.direct.photflam = beam.direct.ref_photflam
beam.init_epsf(yoff=0.0, skip=psf*1, N=4, get_extended=True)
#beam.compute_model = beam.compute_model_psf
#beam.beam.compute_model = beam.beam.compute_model_psf
beam.compute_model(use_psf=True)
m = beam.compute_model(in_place=False)
#beam.modelf = beam.model.flatten()
#beam.model = beam.modelf.reshape(beam.beam.sh_beam)
beam.flat_flam = beam.compute_model(in_place=False,
is_cgs=True)
_p = beam.grism.parent_file
self.psf_param_dict[_p] = beam.beam.psf_params
self._parse_beam_arrays()
def _parse_beam_arrays(self):
"""
"""
self.poly_order = None
self.shapes = [beam.model.shape for beam in self.beams]
self.Nflat = [np.prod(shape) for shape in self.shapes]
self.Ntot = np.sum(self.Nflat)
for b in self.beams:
if hasattr(b, 'xp_mask'):
delattr(b, 'xp_mask')
# Big array of normalized wavelengths (wave / 1.e4 - 1)
self.xpf = np.hstack([np.dot(np.ones((b.beam.sh_beam[0], 1)),
b.beam.lam[None, :]).flatten()/1.e4
for b in self.beams]) - 1
# Flat-flambda model spectra
self.flat_flam = np.hstack([b.flat_flam for b in self.beams])
self.fit_mask = np.hstack([b.fit_mask*b.contam_mask
for b in self.beams])
self.DoF = self.fit_mask.sum()
# systematic error
for i, b in enumerate(self.beams):
if hasattr(b, 'has_sys_err'):
continue
sciu = b.scif.reshape(b.sh)
ivar = 1./(1/b.ivar + (self.sys_err*sciu)**2)
ivar[~np.isfinite(ivar)] = 0
b.ivar = ivar*1
b.ivarf = b.ivar.flatten()
self.ivarf = np.hstack([b.ivarf for b in self.beams])
self.fit_mask &= (self.ivarf >= 0)
self.scif = np.hstack([b.scif for b in self.beams])
self.idf = np.hstack([b.scif*0+ib for ib, b in enumerate(self.beams)])
self.idf = np.asarray(self.idf,dtype=int)
#self.ivarf = 1./(1/self.ivarf + (self.sys_err*self.scif)**2)
self.ivarf[~np.isfinite(self.ivarf)] = 0
self.sivarf = np.sqrt(self.ivarf)
self.wavef = np.hstack([b.wavef for b in self.beams])
self.contamf = np.hstack([b.contam.flatten() for b in self.beams])
weightf = np.exp(-(self.fcontam*np.abs(self.contamf)*self.sivarf))
weightf[~np.isfinite(weightf)] = 0
self.weightf = weightf
self.fit_mask &= self.weightf > 0
self.slices = self._get_slices(masked=False)
self._update_beam_mask()
self.DoF = int((self.weightf*self.fit_mask).sum())
self.Nmask = np.sum([b.fit_mask.sum() for b in self.beams])
# Initialize background fit array
# self.A_bg = np.zeros((self.N, self.Ntot))
# i0 = 0
# for i in range(self.N):
# self.A_bg[i, i0:i0+self.Nflat[i]] = 1.
# i0 += self.Nflat[i]
self.A_bg = self._init_background(masked=False)
self.Asave = {}
self.A_bgm = self._init_background(masked=True)
self.init_poly_coeffs(poly_order=1)
self.ra, self.dec = self.beams[0].get_sky_coords()
[docs] def compute_exptime(self):
"""
Compute number of exposures and total exposure time for each grism
"""
exptime = {}
nexposures = {}
for beam in self.beams:
if beam.grism.instrument == 'NIRISS':
grism = beam.grism.pupil
else:
grism = beam.grism.filter
if grism in exptime:
exptime[grism] += beam.grism.exptime
nexposures[grism] += 1
else:
exptime[grism] = beam.grism.exptime
nexposures[grism] = 1
return nexposures, exptime
[docs] def extend(self, new, verbose=True):
"""Concatenate `~grizli.multifit.MultiBeam` objects
Parameters
----------
new : `~grizli.multifit.MultiBeam`
Beam object containing new beams to add.
verbose : bool
Print summary of the change.
"""
self.beams.extend(new.beams)
self._parse_beams()
if verbose:
print('Add beams: {0}\n Now: {1}'.format(new.Ngrism, self.Ngrism))
[docs] def write_master_fits(self, verbose=True, get_hdu=False, strip=True, include_model=False, get_trace_table=True):
"""Store all beams in a single HDU
TBD
"""
hdu = pyfits.HDUList([pyfits.PrimaryHDU()])
rd = self.beams[0].get_sky_coords()
hdu[0].header['ID'] = (self.id, 'Object ID')
hdu[0].header['RA'] = (rd[0], 'Right Ascension')
hdu[0].header['DEC'] = (rd[1], 'Declination')
exptime = {}
for g in self.Ngrism:
exptime[g] = 0.
count = []
for ib, beam in enumerate(self.beams):
hdu_i = beam.write_fits(get_hdu=True, strip=strip,
include_model=include_model,
get_trace_table=get_trace_table)
hdu.extend(hdu_i[1:])
count.append(len(hdu_i)-1)
hdu[0].header['FILE{0:04d}'.format(ib)] = (beam.grism.parent_file, 'Grism parent file')
hdu[0].header['GRIS{0:04d}'.format(ib)] = (beam.grism.filter, 'Grism element')
hdu[0].header['NEXT{0:04d}'.format(ib)] = (count[-1], 'Number of extensions')
try:
exptime[beam.grism.filter] += beam.grism.header['EXPTIME']
except:
exptime[beam.grism.pupil] += beam.grism.header['EXPTIME']
hdu[0].header['COUNT'] = (self.N, ' '.join(['{0}'.format(c) for c in count]))
for g in self.Ngrism:
hdu[0].header['T_{0}'.format(g)] = (exptime[g], 'Exposure time in grism {0}'.format(g))
if get_hdu:
return hdu
outfile = '{0}_{1:05d}.beams.fits'.format(self.group_name, self.id)
if verbose:
print(outfile)
hdu.writeto(outfile, overwrite=True)
[docs] def load_master_fits(self, beam_file, verbose=True):
"""
Load a "beams.fits" file.
"""
import copy
try:
utils.fetch_acs_wcs_files(beam_file)
except:
pass
if verbose:
print('load_master_fits: {0}'.format(beam_file))
# check if isJWST
isJWST = prep.check_isJWST(beam_file)
hdu = pyfits.open(beam_file, lazy_load_hdus=False)
N = hdu[0].header['COUNT']
Next = np.asarray(hdu[0].header.comments['COUNT'].split(),dtype=int)
i0 = 1
self.beams = []
for i in range(N):
key = 'NEXT{0:04d}'.format(i)
if key in hdu[0].header:
Next_i = hdu[0].header[key]
else:
Next_i = 6 # Assume doesn't have direct SCI/ERR cutouts
# Testing for multiprocessing
if True:
hducopy = hdu[i0:i0+Next_i]
else:
# print('Copy!')
hducopy = pyfits.HDUList([hdu[i].__class__(data=hdu[i].data*1, header=copy.deepcopy(hdu[i].header), name=hdu[i].name) for i in range(i0, i0+Next_i)])
beam = model.BeamCutout(fits_file=hducopy, min_mask=self.min_mask,
min_sens=self.min_sens,
mask_resid=self.mask_resid,
restore_medfilt=self.restore_medfilt)
self.beams.append(beam)
if verbose:
print('{0} {1} {2}'.format(i+1, beam.grism.parent_file, beam.grism.filter))
i0 += Next_i # 6#Next[i]
hdu.close()
[docs] def write_beam_fits(self, verbose=True):
"""TBD
"""
outfiles = []
for beam in self.beams:
root = beam.grism.parent_file.split('.fits')[0]
outfile = beam.write_fits(root)
if verbose:
print('Wrote {0}'.format(outfile))
outfiles.append(outfile)
return outfiles
[docs] def load_beam_fits(self, beam_list, conf=None, verbose=True):
"""TBD
"""
self.beams = []
for file in beam_list:
if verbose:
print(file)
beam = model.BeamCutout(fits_file=file, conf=conf,
min_mask=self.min_mask,
min_sens=self.min_sens,
mask_resid=self.mask_resid)
self.beams.append(beam)
[docs] def replace_segmentation_image_cutouts(self, ref_image='gdn-100mas-f160w_seg.fits'):
"""
Replace "REF" extensions in a `beams.fits` file
Parameters
----------
ref_image : str, `~astropy.io.fits.HDUList`, `~astropy.io.fits.ImageHDU`
Filename or preloaded FITS file.
Returns
-------
beams_image : `~astropy.io.fits.HDUList`
Image object with the "REF" extensions filled with the new blotted
image cutouts.
"""
if isinstance(ref_image, pyfits.HDUList):
ref_data = ref_image[0].data
ref_header = ref_image[0].header
ref_image_filename = ref_image.filename()
elif (isinstance(ref_image, pyfits.ImageHDU) |
isinstance(ref_image, pyfits.PrimaryHDU)):
ref_data = ref_image.data
ref_header = ref_image.header
ref_image_filename = 'HDU'
else:
with pyfits.open(ref_image) as ref_im:
ref_data = ref_im[0].data*1
ref_header = ref_im[0].header.copy()
ref_image_filename = ref_image
ref_wcs = pywcs.WCS(ref_header, relax=True)
ref_wcs.pscale = utils.get_wcs_pscale(ref_wcs)
ref_data = ref_data.astype(np.float32)
for ib in range(self.N):
wcs_copy = self.beams[ib].direct.wcs
if hasattr(wcs_copy, 'idcscale'):
if wcs_copy.idcscale is None:
delattr(wcs_copy, 'idcscale')
in_data, in_wcs, out_wcs = ref_data, ref_wcs, wcs_copy
blot_seg = utils.blot_nearest_exact(ref_data, ref_wcs, wcs_copy,
verbose=True, stepsize=-1,
scale_by_pixel_area=False)
self.beams[ib].beam.set_segmentation(blot_seg)
[docs] def replace_direct_image_cutouts(self, ref_image='gdn-100mas-f160w_drz_sci.fits', ext=0, interp='poly5', cutout=200, background_func=np.median, thumb_labels=None):
"""
Replace "REF" extensions in a `beams.fits` file
Parameters
----------
ref_image : str or `~astropy.io.fits.HDUList`
Filename or preloaded FITS file.
interp : str
Interpolation function to use for `~drizzlepac.astrodrizzle.ablot.do_blot`.
cutout : int
Make a slice of the `ref_image` with size [-cutout,+cutout] around
the center position of the desired object before passing to
`blot`.
background_func : function, None
If not `None`, compute local background with value from
`background_func(ref_image[cutout])`.
Returns
-------
beams_image : `~astropy.io.fits.HDUList`
Image object with the "REF" extensions filled with the new blotted
image cutouts.
"""
from drizzlepac.astrodrizzle import ablot
if isinstance(ref_image, pyfits.HDUList):
ref_data = ref_image[0].data
ref_header = ref_image[0].header
ref_image_filename = ref_image.filename()
elif (isinstance(ref_image, pyfits.ImageHDU) |
isinstance(ref_image, pyfits.PrimaryHDU)):
ref_data = ref_image.data
ref_header = ref_image.header
ref_image_filename = 'HDU'
else:
with pyfits.open(ref_image)[ext] as ref_im:
ref_data = ref_im.data*1
ref_header = ref_im.header.copy()
ref_image_filename = ref_image
if ref_data.dtype not in [np.float32, np.dtype('>f4')]:
ref_data = ref_data.astype(np.float32)
ref_wcs = pywcs.WCS(ref_header, relax=True)
ref_wcs.pscale = utils.get_wcs_pscale(ref_wcs)
if not hasattr(ref_wcs, '_naxis1') & hasattr(ref_wcs, '_naxis'):
ref_wcs._naxis1, ref_wcs._naxis2 = ref_wcs._naxis
if 'PHOTPLAM' in ref_header:
ref_photplam = ref_header['PHOTPLAM']
else:
ref_photplam = 1.
if 'PHOTFLAM' in ref_header:
ref_photflam = ref_header['PHOTFLAM']
else:
ref_photflam = 1.
try:
ref_filter = utils.parse_filter_from_header(ref_header)
except:
ref_filter = 'N/A'
beam_ra, beam_dec = self.ra, self.dec
xy = np.asarray(np.round(ref_wcs.all_world2pix([beam_ra], [beam_dec], 0)),dtype=int).flatten()
sh = ref_data.shape
slx = slice(np.maximum(xy[0]-cutout, 0),
np.minimum(xy[0]+cutout, sh[1]))
sly = slice(np.maximum(xy[1]-cutout, 0),
np.minimum(xy[1]+cutout, sh[0]))
bkg_data = None
for ie in range(self.N):
wcs_copy = self.beams[ie].direct.wcs
if hasattr(wcs_copy, 'idcscale'):
if wcs_copy.idcscale is None:
delattr(wcs_copy, 'idcscale')
if not hasattr(wcs_copy, '_naxis1') & hasattr(wcs_copy, '_naxis'):
wcs_copy._naxis1, wcs_copy._naxis2 = wcs_copy._naxis
blotted = ablot.do_blot(ref_data[sly, slx],
ref_wcs.slice([sly, slx]),
wcs_copy, 1, coeffs=True, interp=interp,
sinscl=1.0, stepsize=10, wcsmap=None)
if background_func is not None:
msk = self.beams[ie].beam.seg == 0
#print(msk.shape, blotted.shape, ie)
if msk.sum() > 0:
if bkg_data is None:
bkg_data = blotted[msk]
else:
bkg_data = np.append(bkg_data, blotted[msk])
if thumb_labels is None:
self.beams[ie].direct.data['REF'] = blotted*ref_photflam
self.beams[ie].direct.ref_photflam = ref_photflam
self.beams[ie].direct.ref_photplam = ref_photplam
self.beams[ie].direct.ref_filter = ref_filter
# self.beams[ie].direct.ref_photflam
self.beams[ie].beam.direct = blotted*ref_photflam
else:
for label in thumb_labels:
self.beams[ie].thumbs[label] = blotted*ref_photflam
if bkg_data is not None:
for ie in range(self.N):
bkg_value = background_func(bkg_data)*ref_photflam
if thumb_labels is None:
self.beams[ie].direct.data['REF'] -= bkg_value
else:
for label in thumb_labels:
self.beams[ie].thumbs[label] -= bkg_value
## Recompute total_flux attribute
for b in self.beams:
b.beam.set_segmentation(b.beam.seg)
[docs] def reshape_flat(self, flat_array):
"""TBD
"""
out = []
i0 = 0
for ib in range(self.N):
im2d = flat_array[i0:i0+self.Nflat[ib]].reshape(self.shapes[ib])
out.append(im2d)
i0 += self.Nflat[ib]
return out
[docs] def init_poly_coeffs(self, flat=None, poly_order=1):
"""TBD
"""
# Already done?
if poly_order < 0:
ok_poly = False
poly_order = 0
else:
ok_poly = True
if poly_order == self.poly_order:
return None
self.poly_order = poly_order
if flat is None:
flat = self.flat_flam
# Polynomial continuum arrays
self.A_poly = np.array([self.xpf**order*flat
for order in range(poly_order+1)])
self.A_poly *= ok_poly
self.n_poly = poly_order + 1
self.x_poly = np.array([(self.beams[0].beam.lam/1.e4-1)**order
for order in range(poly_order+1)])
[docs] def eval_poly_spec(self, coeffs_full):
"""Evaluate polynomial spectrum
"""
xspec = np.arange(self.polyx[0], self.polyx[1], 0.05)
if len(self.polyx) > 2:
px0 = self.polyx[2]
else:
px0 = 1.0
i0 = self.N*self.fit_bg
scale_coeffs = coeffs_full[i0:i0+self.n_poly]
#yspec = [xspec**o*scale_coeffs[o] for o in range(self.poly_order+1)]
yfull = np.polynomial.Polynomial(scale_coeffs)(xspec - px0)
return xspec, yfull
[docs] def compute_model(self, id=None, spectrum_1d=None, is_cgs=False, apply_sensitivity=True, scale=None, reset=True):
"""
Compute the dispersed 2D model for an assumed input spectrum
This is a wrapper around the
`grizli.model.GrismDisperser.compute_model` method, where the
parameters are described.
Nothing returned, but the `model` and `modelf` attributes are
updated on the `~grizli.model.GrismDisperser` subcomponents of the
`beams` list.
"""
for beam in self.beams:
beam.beam.compute_model(id=id, spectrum_1d=spectrum_1d,
is_cgs=is_cgs,
scale=scale, reset=reset,
apply_sensitivity=apply_sensitivity)
beam.modelf = beam.beam.modelf
beam.model = beam.beam.modelf.reshape(beam.beam.sh_beam)
[docs] def compute_model_psf(self, id=None, spectrum_1d=None, is_cgs=False):
"""
Compute the dispersed 2D model for an assumed input spectrum and for
ePSF morphologies
This is a wrapper around the
`grizli.model.GrismDisperser.compute_model_psf` method, where the
parameters are described.
Nothing returned, but the `model` and `modelf` attributes are
updated on the `~grizli.model.GrismDisperser` subcomponents of the
`beams` list.
"""
for beam in self.beams:
beam.beam.compute_model_psf(id=id, spectrum_1d=spectrum_1d,
is_cgs=is_cgs)
beam.modelf = beam.beam.modelf
beam.model = beam.beam.modelf.reshape(beam.beam.sh_beam)
[docs] def fit_at_z(self, z=0., templates={}, fitter='nnls',
fit_background=True, poly_order=0):
"""TBD
"""
try:
import sklearn.linear_model
HAS_SKLEARN = True
except:
HAS_SKLEARN = False
import numpy.linalg
import scipy.optimize
# print 'xxx Init poly'
self.init_poly_coeffs(poly_order=poly_order)
# print 'xxx Init bg'
if fit_background:
self.fit_bg = True
A = np.vstack((self.A_bg, self.A_poly))
else:
self.fit_bg = False
A = self.A_poly*1
NTEMP = len(templates)
A_temp = np.zeros((NTEMP, self.Ntot))
# print 'xxx Load templates'
for i, key in enumerate(templates.keys()):
NTEMP += 1
temp = templates[key] # .zscale(z, 1.)
if hasattr(temp, 'flux_flam'):
# eazy-py Template object
spectrum_1d = [temp.wave*(1+z), temp.flux_flam(z=z)/(1+z)]
else:
spectrum_1d = [temp.wave*(1+z), temp.flux/(1+z)]
if z > 4:
try:
import eazy.igm
igm = eazy.igm.Inoue14()
igmz = igm.full_IGM(z, spectrum_1d[0])
spectrum_1d[1] *= igmz
# print('IGM')
except:
# No IGM
pass
i0 = 0
for ib in range(self.N):
beam = self.beams[ib]
lam_beam = beam.beam.lam_beam
if ((temp.wave.min()*(1+z) > lam_beam.max()) |
(temp.wave.max()*(1+z) < lam_beam.min())):
tmodel = 0.
else:
tmodel = beam.compute_model(spectrum_1d=spectrum_1d,
in_place=False, is_cgs=True) # /beam.beam.total_flux
A_temp[i, i0:i0+self.Nflat[ib]] = tmodel # .flatten()
i0 += self.Nflat[ib]
if NTEMP > 0:
A = np.vstack((A, A_temp))
ok_temp = np.sum(A, axis=1) > 0
out_coeffs = np.zeros(A.shape[0])
# LSTSQ coefficients
# print 'xxx Fitter'
fit_functions = {'lstsq': np.linalg.lstsq, 'nnls': scipy.optimize.nnls}
if fitter in fit_functions:
# 'lstsq':
Ax = A[:, self.fit_mask][ok_temp, :].T
# Weight by ivar
Ax *= np.sqrt(self.ivarf[self.fit_mask][:, np.newaxis])
# print 'xxx lstsq'
#out = numpy.linalg.lstsq(Ax,y)
if fitter == 'lstsq':
y = self.scif[self.fit_mask]
# Weight by ivar
y *= np.sqrt(self.ivarf[self.fit_mask])
try:
out = np.linalg.lstsq(Ax, y, rcond=utils.LSTSQ_RCOND)
except:
print(A.min(), Ax.min(), self.fit_mask.sum(), y.min())
raise ValueError
lstsq_coeff, residuals, rank, s = out
coeffs = lstsq_coeff
if fitter == 'nnls':
if fit_background:
off = 0.04
y = self.scif[self.fit_mask]+off
y *= np.sqrt(self.ivarf[self.fit_mask])
coeffs, rnorm = scipy.optimize.nnls(Ax, y+off)
coeffs[:self.N] -= 0.04
else:
y = self.scif[self.fit_mask]
y *= np.sqrt(self.ivarf[self.fit_mask])
coeffs, rnorm = scipy.optimize.nnls(Ax, y)
# if fitter == 'bounded':
# if fit_background:
# off = 0.04
# y = self.scif[self.fit_mask]+off
# y *= self.ivarf[self.fit_mask]
#
# coeffs, rnorm = scipy.optimize.nnls(Ax, y+off)
# coeffs[:self.N] -= 0.04
# else:
# y = self.scif[self.fit_mask]
# y *= np.sqrt(self.ivarf[self.fit_mask])
#
# coeffs, rnorm = scipy.optimize.nnls(Ax, y)
#
# out = scipy.optimize.minimize(self.eval_trace_shift, shifts, bounds=bounds, args=args, method='Powell', tol=tol)
elif HAS_SKLEARN:
Ax = A[:, self.fit_mask][ok_temp, :].T
y = self.scif[self.fit_mask]
# Wieght by ivar
Ax *= np.sqrt(self.ivarf[self.fit_mask][:, np.newaxis])
y *= np.sqrt(self.ivarf[self.fit_mask])
clf = sklearn.linear_model.LinearRegression()
status = clf.fit(Ax, y)
coeffs = clf.coef_
out_coeffs[ok_temp] = coeffs
modelf = np.dot(out_coeffs, A)
chi2 = np.sum((self.weightf*(self.scif - modelf)**2*self.ivarf)[self.fit_mask])
if fit_background:
poly_coeffs = out_coeffs[self.N:self.N+self.n_poly]
else:
poly_coeffs = out_coeffs[:self.n_poly]
self.y_poly = np.dot(poly_coeffs, self.x_poly)
# x_poly = self.x_poly[1,:]+1 = self.beams[0].beam.lam/1.e4
return A, out_coeffs, chi2, modelf
[docs] def parse_fit_outputs(self, z, templates, coeffs_full, A):
"""Parse output from `fit_at_z`.
Parameters
----------
z : float
Redshift at which to evaluate the fits.
templates : list of `~grizli.utils.SpectrumTemplate` objects
Generated with, e.g., `~grizli.utils.load_templates`.
coeffs_full : `~np.ndarray`
Template fit coefficients
A : `~np.ndarray`
Matrix generated for fits and used for computing model 2D spectra:
>>> model_flat = np.dot(coeffs_full, A)
>>> # mb = MultiBeam(...)
>>> all_models = mb.reshape_flat(model_flat)
>>> m0 = all_models[0] # model for mb.beams[0]
Returns
-------
line_flux : dict
Line fluxes and uncertainties, in cgs units (erg/s/cm2)
covar : `~np.ndarray`
Covariance matrix for the fit coefficients
cont1d, line1d, model1d : `~grizli.utils.SpectrumTemplate`
Best-fit continuum, line, and full (continuum + line) templates
model_continuum : `~np.ndarray`
Flat array of the best fit 2D continuum
"""
from collections import OrderedDict
# Covariance matrix for line flux uncertainties
Ax = A[:, self.fit_mask]
ok_temp = (np.sum(Ax, axis=1) > 0) & (coeffs_full != 0)
Ax = Ax[ok_temp, :].T*1 # A[:, self.fit_mask][ok_temp,:].T
Ax *= np.sqrt(self.ivarf[self.fit_mask][:, np.newaxis])
try:
#covar = np.matrix(np.dot(Ax.T, Ax)).I
covar = utils.safe_invert(np.dot(Ax.T, Ax))
covard = np.sqrt(covar.diagonal())
except:
N = ok_temp.sum()
covar = np.zeros((N, N))
covard = np.zeros(N) # -1.
covar_full = utils.fill_masked_covar(covar, ok_temp)
# Random draws from covariance matrix
# draws = np.random.multivariate_normal(coeffs_full[ok_temp], covar, size=500)
line_flux_err = coeffs_full*0.
line_flux_err[ok_temp] = covard
# Continuum fit
mask = np.isfinite(coeffs_full)
for i, key in enumerate(templates.keys()):
if key.startswith('line'):
mask[self.N*self.fit_bg+self.n_poly+i] = False
model_continuum = np.dot(coeffs_full*mask, A)
self.model_continuum = self.reshape_flat(model_continuum)
# model_continuum.reshape(self.beam.sh_beam)
# 1D spectrum
# Polynomial component
xspec, yspec = self.eval_poly_spec(coeffs_full)
model1d = utils.SpectrumTemplate(xspec*1.e4, yspec)
cont1d = model1d*1
i0 = self.fit_bg*self.N + self.n_poly
line_flux = OrderedDict()
fscl = 1. # self.beams[0].beam.total_flux/1.e-17
line1d = OrderedDict()
for i, key in enumerate(templates.keys()):
temp_i = templates[key].zscale(z, coeffs_full[i0+i])
model1d += temp_i
if not key.startswith('line'):
cont1d += temp_i
else:
line1d[key.split()[1]] = temp_i
line_flux[key.split()[1]] = np.array([coeffs_full[i0+i]*fscl,
line_flux_err[i0+i]*fscl])
return line_flux, covar_full, cont1d, line1d, model1d, model_continuum
[docs] def fit_stars(self, poly_order=1, fitter='nnls', fit_background=True,
verbose=True, make_figure=True, zoom=None,
delta_chi2_threshold=0.004, zr=0, dz=0, fwhm=0,
prior=None, templates={}, figsize=[8, 5],
fsps_templates=False):
"""TBD
"""
# Polynomial fit
out = self.fit_at_z(z=0., templates={}, fitter='lstsq',
poly_order=3,
fit_background=fit_background)
A, coeffs, chi2_poly, model_2d = out
# Star templates
templates = utils.load_templates(fwhm=fwhm, stars=True)
NTEMP = len(templates)
key = list(templates)[0]
temp_i = {key: templates[key]}
out = self.fit_at_z(z=0., templates=temp_i, fitter=fitter,
poly_order=poly_order,
fit_background=fit_background)
A, coeffs, chi2, model_2d = out
chi2 = np.zeros(NTEMP)
coeffs = np.zeros((NTEMP, coeffs.shape[0]))
chi2min = 1e30
iz = 0
best = key
for i, key in enumerate(list(templates)):
temp_i = {key: templates[key]}
out = self.fit_at_z(z=0., templates=temp_i,
fitter=fitter, poly_order=poly_order,
fit_background=fit_background)
A, coeffs[i, :], chi2[i], model_2d = out
if chi2[i] < chi2min:
iz = i
chi2min = chi2[i]
best = key
if verbose:
print(utils.NO_NEWLINE + ' {0} {1:9.1f} ({2})'.format(key, chi2[i], best))
# Best-fit
temp_i = {best: templates[best]}
out = self.fit_at_z(z=0., templates=temp_i,
fitter=fitter, poly_order=poly_order,
fit_background=fit_background)
A, coeffs_full, chi2_best, model_full = out
# Continuum fit
mask = np.isfinite(coeffs_full)
for i, key in enumerate(templates.keys()):
if key.startswith('line'):
mask[self.N*self.fit_bg+self.n_poly+i] = False
model_continuum = np.dot(coeffs_full*mask, A)
self.model_continuum = self.reshape_flat(model_continuum)
# model_continuum.reshape(self.beam.sh_beam)
# 1D spectrum
# xspec = np.arange(0.3, 2.35, 0.05)-1
# scale_coeffs = coeffs_full[self.N*self.fit_bg:
# self.N*self.fit_bg+self.n_poly]
#
# yspec = [xspec**o*scale_coeffs[o] for o in range(self.poly_order+1)]
xspec, yspec = self.eval_poly_spec(coeffs_full)
model1d = utils.SpectrumTemplate(xspec*1.e4, yspec)
cont1d = model1d*1
i0 = self.fit_bg*self.N + self.n_poly
line_flux = OrderedDict()
fscl = 1. # self.beams[0].beam.total_flux/1.e-17
temp_i = templates[best].zscale(0, coeffs_full[i0])
model1d += temp_i
cont1d += temp_i
fit_data = OrderedDict()
fit_data['poly_order'] = poly_order
fit_data['fwhm'] = 0
fit_data['zbest'] = np.argmin(chi2)
fit_data['chibest'] = chi2_best
fit_data['chi_poly'] = chi2_poly
fit_data['zgrid'] = np.arange(NTEMP)
fit_data['prior'] = 1
fit_data['A'] = A
fit_data['coeffs'] = coeffs
fit_data['chi2'] = chi2
fit_data['DoF'] = self.DoF
fit_data['model_full'] = model_full
fit_data['coeffs_full'] = coeffs_full
fit_data['line_flux'] = {}
#fit_data['templates_full'] = templates
fit_data['model_cont'] = model_continuum
fit_data['model1d'] = model1d
fit_data['cont1d'] = cont1d
# return fit_data
fig = None
if make_figure:
fig = self.show_redshift_fit(fit_data)
# fig.savefig('fit.pdf')
return fit_data, fig
[docs] def fit_redshift(self, prior=None, poly_order=1, fwhm=1200,
make_figure=True, zr=None, dz=None, verbose=True,
fit_background=True, fitter='nnls',
delta_chi2_threshold=0.004, zoom=True,
line_complexes=True, templates={}, figsize=[8, 5],
fsps_templates=False):
"""TBD
"""
from numpy.polynomial import Polynomial
if zr is None:
zr = [0.65, 1.6]
if dz is None:
dz = [0.005, 0.0004]
if zr in [0]:
stars = True
zr = [0, 0.01]
fitter = 'nnls'
else:
stars = False
zgrid = utils.log_zgrid(zr, dz=dz[0])
NZ = len(zgrid)
# Polynomial fit
out = self.fit_at_z(z=0., templates={}, fitter='lstsq',
poly_order=3,
fit_background=fit_background)
A, coeffs, chi2_poly, model_2d = out
# Set up for template fit
if templates == {}:
templates = utils.load_templates(fwhm=fwhm, stars=stars, line_complexes=line_complexes, fsps_templates=fsps_templates)
else:
if verbose:
print('User templates! N={0} \n'.format(len(templates)))
NTEMP = len(templates)
out = self.fit_at_z(z=0., templates=templates, fitter=fitter,
poly_order=poly_order,
fit_background=fit_background)
A, coeffs, chi2, model_2d = out
chi2 = np.zeros(NZ)
coeffs = np.zeros((NZ, coeffs.shape[0]))
chi2min = 1e30
iz = 0
for i in range(NZ):
out = self.fit_at_z(z=zgrid[i], templates=templates,
fitter=fitter, poly_order=poly_order,
fit_background=fit_background)
A, coeffs[i, :], chi2[i], model_2d = out
if chi2[i] < chi2min:
iz = i
chi2min = chi2[i]
if verbose:
print(utils.NO_NEWLINE + ' {0:.4f} {1:9.1f} ({2:.4f})'.format(zgrid[i], chi2[i], zgrid[iz]))
print('First iteration: z_best={0:.4f}\n'.format(zgrid[iz]))
# peaks
# chi2nu = (chi2.min()-chi2)/self.DoF
# indexes = utils.find_peaks((chi2nu+delta_chi2_threshold)*(chi2nu > - delta_chi2_threshold), threshold=0.3, min_dist=21)
chi2_rev = (chi2_poly - chi2)/self.DoF
if chi2_poly < (chi2.min() + 9):
chi2_rev = (chi2.min() + 16 - chi2)/self.DoF
chi2_rev[chi2_rev < 0] = 0
indexes = utils.find_peaks(chi2_rev, threshold=0.4, min_dist=9)
num_peaks = len(indexes)
if False:
plt.plot(zgrid, (chi2-chi2.min()) / self.DoF)
plt.scatter(zgrid[indexes], (chi2-chi2.min())[indexes] / self.DoF, color='r')
# delta_chi2 = (chi2.max()-chi2.min())/self.DoF
# if delta_chi2 > delta_chi2_threshold:
if (num_peaks > 0) & (not stars) & zoom:
zgrid_zoom = []
for ix in indexes:
if (ix > 0) & (ix < len(chi2)-1):
p = Polynomial.fit(zgrid[ix-1:ix+2], chi2[ix-1:ix+2], deg=2)
zi = p.deriv().roots()[0]
chi_i = p(zi)
zgrid_zoom.extend(np.arange(zi-2*dz[0],
zi+2*dz[0]+dz[1]/10., dz[1]))
# zgrid_zoom = utils.zoom_zgrid(zgrid, chi2/self.DoF,
# threshold=delta_chi2_threshold,
# factor=dz[0]/dz[1])
NZOOM = len(zgrid_zoom)
chi2_zoom = np.zeros(NZOOM)
coeffs_zoom = np.zeros((NZOOM, coeffs.shape[1]))
iz = 0
chi2min = 1.e30
for i in range(NZOOM):
out = self.fit_at_z(z=zgrid_zoom[i], templates=templates,
fitter=fitter, poly_order=poly_order,
fit_background=fit_background)
A, coeffs_zoom[i, :], chi2_zoom[i], model_2d = out
if chi2_zoom[i] < chi2min:
chi2min = chi2_zoom[i]
iz = i
if verbose:
print(utils.NO_NEWLINE+'- {0:.4f} {1:9.1f} ({2:.4f}) {3:d}/{4:d}'.format(zgrid_zoom[i], chi2_zoom[i], zgrid_zoom[iz], i+1, NZOOM))
zgrid = np.append(zgrid, zgrid_zoom)
chi2 = np.append(chi2, chi2_zoom)
coeffs = np.append(coeffs, coeffs_zoom, axis=0)
so = np.argsort(zgrid)
zgrid = zgrid[so]
chi2 = chi2[so]
coeffs = coeffs[so, :]
if prior is not None:
#print('\n\nPrior!\n\n', chi2.min(), prior[1].min())
interp_prior = np.interp(zgrid, prior[0], prior[1])
chi2 += interp_prior
else:
interp_prior = None
print(' Zoom iteration: z_best={0:.4f}\n'.format(zgrid[np.argmin(chi2)]))
# Best redshift
if not stars:
templates = utils.load_templates(line_complexes=False, fwhm=fwhm, fsps_templates=fsps_templates)
zbest = zgrid[np.argmin(chi2)]
ix = np.argmin(chi2)
chibest = chi2.min()
# Fit parabola
if (ix > 0) & (ix < len(chi2)-1):
p = Polynomial.fit(zgrid[ix-1:ix+2], chi2[ix-1:ix+2], deg=2)
zbest = p.deriv().roots()[0]
chibest = p(zbest)
out = self.fit_at_z(z=zbest, templates=templates,
fitter=fitter, poly_order=poly_order,
fit_background=fit_background)
A, coeffs_full, chi2_best, model_full = out
# Parse results
out2 = self.parse_fit_outputs(zbest, templates, coeffs_full, A)
line_flux, covar, cont1d, line1d, model1d, model_continuum = out2
# Output dictionary with fit parameters
fit_data = OrderedDict()
fit_data['poly_order'] = poly_order
fit_data['fwhm'] = fwhm
fit_data['zbest'] = zbest
fit_data['chibest'] = chibest
fit_data['chi_poly'] = chi2_poly
fit_data['zgrid'] = zgrid
fit_data['prior'] = interp_prior
fit_data['A'] = A
fit_data['coeffs'] = coeffs
fit_data['chi2'] = chi2
fit_data['DoF'] = self.DoF
fit_data['model_full'] = model_full
fit_data['coeffs_full'] = coeffs_full
fit_data['covar'] = covar
fit_data['line_flux'] = line_flux
#fit_data['templates_full'] = templates
fit_data['model_cont'] = model_continuum
fit_data['model1d'] = model1d
fit_data['cont1d'] = cont1d
fit_data['line1d'] = line1d
# return fit_data
fig = None
if make_figure:
fig = self.show_redshift_fit(fit_data, figsize=figsize)
# fig.savefig('fit.pdf')
return fit_data, fig
[docs] def run_individual_fits(self, z=0, templates={}):
"""Run template fits on each *exposure* individually to evaluate
variance in line and continuum fits.
Parameters
----------
z : float
Redshift at which to evaluate the fit
templates : list of `~grizli.utils.SpectrumTemplate` objects
Generated with, e.g., `load_templates`.
Returns
-------
line_flux, line_err : dict
Dictionaries with the measured line fluxes and uncertainties for
each exposure fit.
coeffs_list : `~np.ndarray` [Nbeam x Ntemplate]
Raw fit coefficients
chi2_list, DoF_list : `~np.ndarray` [Nbeam]
Chi-squared and effective degrees of freedom for each separate fit
"""
# Fit on the full set of beams
out = self.fit_at_z(z=z, templates=templates,
fitter='nnls', poly_order=self.poly_order,
fit_background=self.fit_bg)
A, coeffs_full, chi2_best, model_full = out
out2 = self.parse_fit_outputs(z, templates, coeffs_full, A)
line, covar, cont1d, line1d, model1d, model_continuum = out2
NB, NTEMP = len(self.beams), len(templates)
# Outputs
coeffs_list = np.zeros((NB, NTEMP))
chi2_list = np.zeros(NB)
DoF_list = np.zeros(NB)
line_flux = OrderedDict()
line_err = OrderedDict()
line_keys = list(line.keys())
for k in line_keys:
line_flux[k] = np.zeros(NB)
line_err[k] = np.zeros(NB)
# Generate separate MultiBeam objects for each individual beam
for i, b in enumerate(self.beams):
b_i = MultiBeam([b], fcontam=self.fcontam,
group_name=self.group_name)
out_i = b_i.fit_at_z(z=z, templates=templates,
fitter='nnls', poly_order=self.poly_order,
fit_background=self.fit_bg)
A_i, coeffs_i, chi2_i, model_full_i = out_i
# Parse fit information from individual fits
out2 = b_i.parse_fit_outputs(z, templates, coeffs_i, A_i)
line_i, covar_i, cont1d_i, line1d_i, model1d_i, model_continuum_i = out2
for k in line_keys:
line_flux[k][i] = line_i[k][0]
line_err[k][i] = line_i[k][1]
coeffs_list[i, :] = coeffs_i[-NTEMP:]
chi2_list[i] = chi2_i
DoF_list[i] = b_i.DoF
return line_flux, line_err, coeffs_list, chi2_list, DoF_list
[docs] def show_redshift_fit(self, fit_data, plot_flambda=True, figsize=[8, 5]):
"""TBD
"""
import matplotlib.gridspec
gs = matplotlib.gridspec.GridSpec(2, 1, height_ratios=[0.6, 1])
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(gs[0])
c2min = fit_data['chi2'].min()
scale_pz = True
if scale_pz:
scale_nu = c2min/self.DoF
scl_label = '_s'
else:
scale_nu = 1.
scl_label = ''
#axz.plot(z, (chi2-chi2.min())/scale_nu, color='k')
#ax.plot(fit_data['zgrid'], fit_data['chi2']/self.DoF)
ax.plot(fit_data['zgrid'], (fit_data['chi2']-c2min)/scale_nu)
ax.set_xlabel('z')
ax.set_ylabel(r'$\chi^2_\nu$, $\nu$={0:d}'.format(self.DoF))
ax.set_ylim(-4, 27)
ax.set_ylabel(r'$\Delta\chi^2{2}$ ({0:.0f}/$\nu$={1:d})'.format(c2min, self.DoF, scl_label))
ax.set_yticks([1, 4, 9, 16, 25])
# for delta in [1,4,9]:
# ax.plot(fit_data['zgrid'],
# fit_data['zgrid']*0.+(c2min+delta)/self.DoF,
# color='{0:.2f}'.format(1-delta*1./10))
ax.plot(fit_data['zgrid'], (fit_data['chi2']*0+fit_data['chi_poly']-c2min)/scale_nu, color='b', linestyle='--', alpha=0.8)
ax.set_xlim(fit_data['zgrid'].min(), fit_data['zgrid'].max())
ax.grid()
ax.set_title(r'ID = {0:d}, $z_\mathrm{{grism}}$={1:.4f}'.format(self.beams[0].id, fit_data['zbest']))
ax = fig.add_subplot(gs[1])
ymax = 0
ymin = 1e10
continuum_fit = self.reshape_flat(fit_data['model_cont'])
line_fit = self.reshape_flat(fit_data['model_full'])
grisms = self.Ngrism.keys()
wfull = {}
ffull = {}
efull = {}
for grism in grisms:
wfull[grism] = []
ffull[grism] = []
efull[grism] = []
for ib in range(self.N):
beam = self.beams[ib]
clean = beam.grism['SCI'] - beam.contam
if self.fit_bg:
bg_i = fit_data['coeffs_full'][ib]
clean -= bg_i # background
else:
bg_i = 0.
#ivar = 1./(1./beam.ivar + self.fcontam*beam.contam)
#ivar[~np.isfinite(ivar)] = 0
# New weight scheme
ivar = beam.ivar
weight = np.exp(-(self.fcontam*np.abs(beam.contam)*np.sqrt(ivar)))
wave, flux, err = beam.beam.optimal_extract(clean,
ivar=ivar,
weight=weight)
mwave, mflux, merr = beam.beam.optimal_extract(line_fit[ib]-bg_i,
ivar=ivar,
weight=weight)
flat = beam.flat_flam.reshape(beam.beam.sh_beam)
wave, fflux, ferr = beam.beam.optimal_extract(flat, ivar=ivar,
weight=weight)
if plot_flambda:
ok = beam.beam.sensitivity > 0.1*beam.beam.sensitivity.max()
wave = wave[ok]
fscl = 1./1.e-19 # beam.beam.total_flux/1.e-17
flux = (flux*fscl/fflux)[ok]*beam.beam.scale
err = (err*fscl/fflux)[ok]
mflux = (mflux*fscl/fflux)[ok]*beam.beam.scale
ylabel = r'$f_\lambda\,/\,10^{-19}\,\mathrm{cgs}$'
else:
ylabel = 'flux (e-/s)'
scl_region = np.isfinite(mflux)
if scl_region.sum() == 0:
continue
# try:
# okerr = np.isfinite(err) #& (np.abs(flux/err) > 0.2) & (err != 0)
# med_err = np.median(err[okerr])
#
# ymax = np.maximum(ymax,
# (mflux[scl_region][2:-2] + med_err).max())
# ymin = np.minimum(ymin,
# (mflux[scl_region][2:-2] - med_err).min())
# except:
# continue
#okerr = (err != 0) & (np.abs(flux/err) > 0.2)
okerr = np.isfinite(err)
ax.errorbar(wave[okerr]/1.e4, flux[okerr], err[okerr], alpha=0.15+0.2*(self.N <= 2), linestyle='None', marker='.', color='{0:.2f}'.format(ib*0.5/self.N), zorder=1)
ax.plot(wave[okerr]/1.e4, mflux[okerr], color='r', alpha=0.5, zorder=3)
if beam.grism.instrument == 'NIRISS':
grism = beam.grism.pupil
else:
grism = beam.grism.filter
# for grism in grisms:
wfull[grism] = np.append(wfull[grism], wave[okerr])
ffull[grism] = np.append(ffull[grism], flux[okerr])
efull[grism] = np.append(efull[grism], err[okerr])
# Scatter direct image flux
if beam.direct.ref_photplam is None:
ax.scatter(beam.direct.photplam/1.e4, beam.beam.total_flux/1.e-19, marker='s', edgecolor='k', color=GRISM_COLORS[grism], alpha=0.2, zorder=100, s=100)
else:
ax.scatter(beam.direct.ref_photplam/1.e4, beam.beam.total_flux/1.e-19, marker='s', edgecolor='k', color=GRISM_COLORS[grism], alpha=0.2, zorder=100, s=100)
for grism in grisms:
if self.Ngrism[grism] > 1:
# binned
okb = (np.isfinite(wfull[grism]) & np.isfinite(ffull[grism]) &
np.isfinite(efull[grism]))
so = np.argsort(wfull[grism][okb])
var = efull[grism]**2
N = int(np.ceil(self.Ngrism[grism]/2)*2)*2
kernel = np.ones(N, dtype=float)/N
wht = 1/var[okb][so]
fbin = nd.convolve(ffull[grism][okb][so]*wht, kernel)[N//2::N]
wbin = nd.convolve(wfull[grism][okb][so]*wht, kernel)[N//2::N]
#vbin = nd.convolve(var[okb][so], kernel**2)[N//2::N]
wht_bin = nd.convolve(wht, kernel)[N//2::N]
vbin = nd.convolve(wht, kernel**2)[N//2::N]/wht_bin**2
fbin /= wht_bin
wbin /= wht_bin
#vbin = 1./wht_bin
ax.errorbar(wbin/1.e4, fbin, np.sqrt(vbin), alpha=0.8,
linestyle='None', marker='.',
color=GRISM_COLORS[grism], zorder=2)
med_err = np.median(np.sqrt(vbin))
ymin = np.minimum(ymin, (fbin-2*med_err).min())
ymax = np.maximum(ymax, (fbin+2*med_err).max())
ymin = np.maximum(0, ymin)
ax.set_ylim(ymin - 0.2*np.abs(ymax), 1.3*ymax)
xmin, xmax = 1.e5, 0
for g in GRISM_LIMITS:
if g in grisms:
xmin = np.minimum(xmin, GRISM_LIMITS[g][0])
xmax = np.maximum(xmax, GRISM_LIMITS[g][1])
# print g, xmin, xmax
ax.set_xlim(xmin, xmax)
ax.semilogx(subsx=[xmax])
# axc.set_xticklabels([])
# axc.set_xlabel(r'$\lambda$')
#axc.set_ylabel(r'$f_\lambda \times 10^{-19}$')
from matplotlib.ticker import MultipleLocator
ax.xaxis.set_major_locator(MultipleLocator(0.1))
labels = np.arange(np.ceil(xmin*10), np.ceil(xmax*10))/10.
ax.set_xticks(labels)
ax.set_xticklabels(labels)
ax.grid()
# Label
ax.text(0.03, 1.03, ('{0}'.format(self.Ngrism)).replace('\'', '').replace('{', '').replace('}', ''), ha='left', va='bottom', transform=ax.transAxes, fontsize=10)
#ax.plot(wave/1.e4, wave/1.e4*0., linestyle='--', color='k')
ax.hlines(0, xmin, xmax, linestyle='--', color='k')
ax.set_xlabel(r'$\lambda$')
ax.set_ylabel(ylabel)
gs.tight_layout(fig, pad=0.1)
return fig
[docs] def drizzle_segmentation(self, wcsobj=None, kernel='square', pixfrac=1, verbose=False):
"""
Drizzle segmentation image from individual `MultiBeam.beams`.
Parameters
----------
wcsobj: `~astropy.wcs.WCS` or `~astropy.io.fits.Header`
Output WCS.
kernel: e.g., 'square', 'point', 'gaussian'
Drizzle kernel, see `~drizzlepac.adrizzle.drizzle`.
pixfrac: float
Drizzle 'pixfrac', see `~drizzlepac.adrizzle.drizzle`.
verbose: bool
Print status messages.
Returns
-------
drizzled_segm: `~numpy.ndarray`, type `~numpy.int64`.
Drizzled segmentation image, with image dimensions and
WCS defined in `wcsobj`.
"""
import numpy as np
import astropy.wcs as pywcs
import astropy.io.fits as pyfits
try:
from . import utils
except:
from grizli import multifit, utils
all_ids = [np.unique(beam.beam.seg) for beam in self.beams]
all_ids = np.unique(np.hstack(all_ids))[1:]
if isinstance(wcsobj, pyfits.Header):
wcs = pywcs.WCS(wcsobj)
wcs.pscale = utils.get_wcs_pscale(wcs)
else:
wcs = wcsobj
if not hasattr(wcs, 'pscale'):
wcs.pscale = utils.get_wcs_pscale(wcs)
if verbose:
print('Drizzle ID={0:.0f} (primary)'.format(self.id))
drizzled_segm = self.drizzle_segmentation_id(id=self.id, wcsobj=wcsobj, kernel=kernel, pixfrac=pixfrac, verbose=verbose)
for id in all_ids:
if int(id) == self.id:
continue
if verbose:
print('Drizzle ID={0:.0f}'.format(id))
dseg_i = self.drizzle_segmentation_id(id=id, wcsobj=wcsobj, kernel=kernel, pixfrac=pixfrac, verbose=False)
new_seg = drizzled_segm == 0
drizzled_segm[new_seg] = dseg_i[new_seg]
return drizzled_segm
[docs] def drizzle_segmentation_id(self, id=None, wcsobj=None, kernel='square', pixfrac=1, verbose=True):
"""
Drizzle segmentation image for a single ID
"""
import numpy as np
import astropy.wcs as pywcs
import astropy.io.fits as pyfits
try:
from . import utils
except:
from grizli import multifit, utils
# Can be either a header or WCS object
if isinstance(wcsobj, pyfits.Header):
wcs = pywcs.WCS(wcsobj)
wcs.pscale = utils.get_wcs_pscale(wcs)
else:
wcs = wcsobj
if not hasattr(wcs, 'pscale'):
wcs.pscale = utils.get_wcs_pscale(wcs)
if id is None:
id = self.id
sci_list = [(beam.beam.seg == id)*1. for beam in self.beams]
wht_list = [np.isfinite(beam.beam.seg)*1. for beam in self.beams]
wcs_list = [beam.direct.wcs for beam in self.beams]
out = utils.drizzle_array_groups(
sci_list,
wht_list,
wcs_list,
outputwcs=wcs,
scale=0.1,
kernel=kernel,
pixfrac=pixfrac,
verbose=verbose
)
drizzled_segm = (out[0] > 0)*id
return drizzled_segm
[docs] def drizzle_fit_lines(self, fit, pline, force_line=['Ha+NII', 'Ha', 'OIII', 'Hb', 'OII'], save_fits=True, mask_lines=True, mask_sn_limit=3, mask_4959=True, verbose=True, include_segmentation=True, get_ir_psfs=True, min_line_sn=4):
"""
TBD
"""
line_wavelengths, line_ratios = utils.get_line_wavelengths()
hdu_full = []
saved_lines = []
if ('cfit' in fit) & mask_4959:
if 'line OIII' in fit['templates']:
t_o3 = utils.load_templates(fwhm=fit['templates']['line OIII'].fwhm, line_complexes=False, stars=False, full_line_list=['OIII-4959'], continuum_list=[], fsps_templates=False)
if 'zbest' in fit:
z_driz = fit['zbest']
else:
z_driz = fit['z']
if 'line_flux' in fit:
line_flux_dict = fit['line_flux']
else:
line_flux_dict = OrderedDict()
for key in fit['cfit']:
if key.startswith('line'):
line_flux_dict[key.replace('line ', '')] = fit['cfit'][key]
# Compute continuum model
if 'cfit' in fit:
if 'bg {0:03d}'.format(self.N-1) in fit['cfit']:
for ib, beam in enumerate(self.beams):
key = 'bg {0:03d}'.format(ib)
self.beams[ib].background = fit['cfit'][key][0]
cont = fit['cont1d']
for beam in self.beams:
beam.compute_model(spectrum_1d=[cont.wave, cont.flux],
is_cgs=True)
if hasattr(self, 'pscale'):
if (self.pscale is not None):
scale = self.compute_scale_array(self.pscale, beam.wavef)
beam.beam.pscale_array = scale.reshape(beam.sh)
else:
beam.beam.pscale_array = 1.
else:
beam.beam.pscale_array = 1.
for line in line_flux_dict:
line_flux, line_err = line_flux_dict[line]
if line_err == 0:
continue
# Skip if min_line_sn = inf
if not np.isfinite(min_line_sn):
continue
if (line_flux/line_err > min_line_sn) | (line in force_line):
if verbose:
print('Drizzle line -> {0:4s} ({1:.2f} {2:.2f})'.format(line, line_flux/1.e-17, line_err/1.e-17))
line_wave_obs = line_wavelengths[line][0]*(1+z_driz)
if mask_lines:
for beam in self.beams:
beam.oivar = beam.ivar*1
lam = beam.beam.lam_beam
if hasattr(beam.beam, 'pscale_array'):
pscale_array = beam.beam.pscale_array
else:
pscale_array = 1.
# another idea, compute a model for the line itself
# and mask relatively "contaminated" pixels from
# other lines
try:
lm = fit['line1d'][line]
sp = [lm.wave, lm.flux]
except:
key = 'line ' + line
lm = fit['templates'][key]
line_flux = fit['cfit'][key][0]
scl = line_flux/(1+z_driz)
sp = [lm.wave*(1+z_driz), lm.flux*scl]
#lm = fit['line1d'][line]
if ((lm.wave.max() < lam.min()) |
(lm.wave.min() > lam.max())):
continue
#sp = [lm.wave, lm.flux]
if line_flux > 0:
m = beam.compute_model(spectrum_1d=sp,
in_place=False, is_cgs=True)
lmodel = m.reshape(beam.beam.sh_beam)*pscale_array
else:
lmodel = np.zeros(beam.beam.sh_beam)
# if lmodel.max() == 0:
# continue
if 'cfit' in fit:
keys = fit['cfit']
else:
keys = fit['line1d']
beam.extra_lines = beam.contam*0.
for lkey in keys:
if not lkey.startswith('line'):
continue
key = lkey.replace('line ', '')
lf, le = line_flux_dict[key]
# Don't mask if the line missing or undetected
if (lf <= 0): # | (lf < mask_sn_limit*le):
continue
if key != line:
try:
lm = fit['line1d'][lkey]
sp = [lm.wave, lm.flux]
except:
lm = fit['templates'][lkey]
scl = fit['cfit'][lkey][0]/(1+z_driz)
sp = [lm.wave*(1+z_driz), lm.flux*scl]
if ((lm.wave.max() < lam.min()) |
(lm.wave.min() > lam.max())):
continue
m = beam.compute_model(spectrum_1d=sp,
in_place=False,
is_cgs=True)
lcontam = m.reshape(beam.beam.sh_beam)
lcontam *= pscale_array
if lcontam.max() == 0:
# print beam.grism.parent_file, lkey
continue
beam.extra_lines += lcontam
# Only mask if line flux > 0
if line_flux > 0:
extra_msk = lcontam > mask_sn_limit*lmodel
extra_msk &= (lcontam > 0)
extra_msk &= (lmodel > 0)
beam.ivar[extra_msk] *= 0
# Subtract 4959
if (line == 'OIII') & ('cfit' in fit) & mask_4959:
lm = t_o3['line OIII-4959']
scl = fit['cfit']['line OIII'][0]/(1+z_driz)
scl *= 1./(2.98+1)
sp = [lm.wave*(1+z_driz), lm.flux*scl]
if ((lm.wave.max() < lam.min()) |
(lm.wave.min() > lam.max())):
continue
m = beam.compute_model(spectrum_1d=sp,
in_place=False,
is_cgs=True)
lcontam = m.reshape(beam.beam.sh_beam)
lcontam *= pscale_array
if lcontam.max() == 0:
continue
#print('Mask 4959!')
beam.extra_lines += lcontam
hdu = drizzle_to_wavelength(self.beams, ra=self.ra,
dec=self.dec, wave=line_wave_obs,
fcontam=self.fcontam,
**pline)
if mask_lines:
for beam in self.beams:
beam.ivar = beam.oivar*1
delattr(beam, 'oivar')
hdu[0].header['REDSHIFT'] = (z_driz, 'Redshift used')
# for e in [3,4,5,6]:
for e in [-4, -3, -2, -1]:
hdu[e].header['EXTVER'] = line
hdu[e].header['REDSHIFT'] = (z_driz, 'Redshift used')
hdu[e].header['RESTWAVE'] = (line_wavelengths[line][0],
'Line rest wavelength')
saved_lines.append(line)
if len(hdu_full) == 0:
hdu_full = hdu
hdu_full[0].header['NUMLINES'] = (1,
"Number of lines in this file")
else:
hdu_full.extend(hdu[-4:])
hdu_full[0].header['NUMLINES'] += 1
# Make sure DSCI extension is filled. Can be empty for
# lines at the edge of the grism throughput
for f_i in range(hdu[0].header['NDFILT']):
filt_i = hdu[0].header['DFILT{0:02d}'.format(f_i+1)]
if hdu['DWHT', filt_i].data.max() != 0:
hdu_full['DSCI', filt_i] = hdu['DSCI', filt_i]
hdu_full['DWHT', filt_i] = hdu['DWHT', filt_i]
li = hdu_full[0].header['NUMLINES']
hdu_full[0].header['LINE{0:03d}'.format(li)] = line
hdu_full[0].header['FLUX{0:03d}'.format(li)] = (line_flux,
'Line flux, 1e-17 erg/s/cm2')
hdu_full[0].header['ERR{0:03d}'.format(li)] = (line_err,
'Line flux err, 1e-17 erg/s/cm2')
if len(hdu_full) > 0:
hdu_full[0].header['HASLINES'] = (' '.join(saved_lines),
'Lines in this file')
else:
hdu = drizzle_to_wavelength(self.beams, ra=self.ra,
dec=self.dec,
wave=np.median(self.beams[0].wave),
fcontam=self.fcontam,
**pline)
hdu_full = hdu[:-4]
hdu_full[0].header['REDSHIFT'] = (z_driz, 'Redshift used')
hdu_full[0].header['NUMLINES'] = 0
hdu_full[0].header['HASLINES'] = ' '
if include_segmentation:
line_wcs = pywcs.WCS(hdu_full[1].header)
segm = self.drizzle_segmentation(wcsobj=line_wcs)
seg_hdu = pyfits.ImageHDU(data=segm.astype(np.int32), name='SEG')
hdu_full.insert(1, seg_hdu)
if get_ir_psfs:
import grizli.galfit.psf
ir_beams = []
gr_filters = {'G102': ['F105W'], 'G141': ['F105W', 'F125W', 'F140W', 'F160W']}
show_filters = []
for gr in ['G102', 'G141']:
if gr in self.PA:
show_filters.extend(gr_filters[gr])
for pa in self.PA[gr]:
for i in self.PA[gr][pa]:
ir_beams.append(self.beams[i])
if len(ir_beams) > 0:
dp = grizli.galfit.psf.DrizzlePSF(driz_hdu=hdu_full['DSCI'],
beams=self.beams)
for filt in np.unique(show_filters):
if verbose:
print('Get linemap PSF: {0}'.format(filt))
psf = dp.get_psf(ra=dp.driz_wcs.wcs.crval[0],
dec=dp.driz_wcs.wcs.crval[1],
filter=filt,
pixfrac=dp.driz_header['PIXFRAC'],
kernel=dp.driz_header['DRIZKRNL'],
wcs_slice=dp.driz_wcs, get_extended=True,
verbose=False, get_weight=False)
psf[1].header['EXTNAME'] = 'DPSF'
psf[1].header['EXTVER'] = filt
hdu_full.append(psf[1])
if save_fits:
hdu_full.writeto('{0}_{1:05d}.line.fits'.format(self.group_name, self.id), overwrite=True, output_verify='silentfix')
return hdu_full
[docs] def run_full_diagnostics(self, pzfit={}, pspec2={}, pline={},
force_line=['Ha+NII', 'Ha', 'OIII', 'Hb', 'OII'],
GroupFLT=None,
prior=None, zoom=True, verbose=True):
"""TBD
size=20, pixscale=0.1,
pixfrac=0.2, kernel='square'
"""
import copy
# Defaults
pzfit_def, pspec2_def, pline_def = get_redshift_fit_defaults()
if pzfit == {}:
pzfit = pzfit_def
if pspec2 == {}:
pspec2 = pspec2_def
if pline == {}:
pline = pline_def
# Check that keywords allowed
for d, default in zip([pzfit, pspec2, pline],
[pzfit_def, pspec2_def, pline_def]):
for key in d:
if key not in default:
p = d.pop(key)
# Auto generate FWHM (in km/s) to use for line fits
if 'fwhm' in pzfit:
fwhm = pzfit['fwhm']
if pzfit['fwhm'] == 0:
if 'G141' in self.Ngrism:
# WFC3/IR
fwhm = 1200
elif 'G800L' in self.Ngrism:
# ACS/WFC
fwhm = 1400
elif 'G280' in self.Ngrism:
# UVIS
fwhm = 1500
elif 'GRISM' in self.Ngrism:
# WFIRST
fwhm = 350
elif 'G150' in self.Ngrism:
# WFIRST
fwhm = 350
else:
fwhm = 700
# Auto generate delta-wavelength of 2D spectrum
if 'dlam' in pspec2:
dlam = pspec2['dlam']
if dlam == 0:
if 'G141' in self.Ngrism:
dlam = 45
elif 'G800L' in self.Ngrism:
dlam = 40
elif 'G280' in self.Ngrism:
dlam = 18
elif 'GRISM' in self.Ngrism:
dlam = 11
elif 'G150' in self.Ngrism:
dlam = 11
else:
dlam = 25 # G102
# Redshift fit
zfit_in = copy.copy(pzfit)
zfit_in['fwhm'] = fwhm
zfit_in['prior'] = prior
zfit_in['zoom'] = zoom
zfit_in['verbose'] = verbose
if zfit_in['zr'] in [0]:
fit, fig = self.fit_stars(**zfit_in)
else:
fit, fig = self.fit_redshift(**zfit_in)
# Make sure model attributes are set to the continuum model
models = self.reshape_flat(fit['model_cont'])
for j in range(self.N):
self.beams[j].model = models[j]*1
# 2D spectrum
spec_in = copy.copy(pspec2)
spec_in['fit'] = fit
spec_in['dlam'] = dlam
# fig2, hdu2 = self.redshift_fit_twod_figure(**spec_in)#, kwargs=spec2) #dlam=dlam, spatial_scale=spatial_scale, NY=NY)
fig2 = hdu2 = None
# Update master model
if GroupFLT is not None:
try:
ix = GroupFLT.catalog['NUMBER'] == self.beams[0].id
mag = GroupFLT.catalog['MAG_AUTO'][ix].data[0]
except:
mag = 22
sp = fit['cont1d']
GroupFLT.compute_single_model(id, mag=mag, size=-1, store=False,
spectrum_1d=[sp.wave, sp.flux],
is_cgs=True,
get_beams=None, in_place=True)
# 2D lines to drizzle
hdu_full = self.drizzle_fit_lines(fit, pline, force_line=force_line,
save_fits=True)
fit['id'] = self.id
fit['fit_bg'] = self.fit_bg
fit['grism_files'] = [b.grism.parent_file for b in self.beams]
for item in ['A', 'coeffs', 'model_full', 'model_cont']:
if item in fit:
p = fit.pop(item)
#p = fit.pop('coeffs')
np.save('{0}_{1:05d}.zfit.npy'.format(self.group_name, self.id), [fit])
fig.savefig('{0}_{1:05d}.zfit.png'.format(self.group_name, self.id))
#fig2.savefig('{0}_{1:05d}.zfit.2D.png'.format(self.group_name, self.id))
#hdu2.writeto('{0}_{1:05d}.zfit.2D.fits'.format(self.group_name, self.id), overwrite=True, output_verify='silentfix')
label = '# id ra dec zbest '
data = '{0:7d} {1:.6f} {2:.6f} {3:.5f}'.format(self.id, self.ra, self.dec,
fit['zbest'])
for grism in ['G800L', 'G280', 'G102', 'G141', 'GRISM']:
label += ' N{0}'.format(grism)
if grism in self.Ngrism:
data += ' {0:2d}'.format(self.Ngrism[grism])
else:
data += ' {0:2d}'.format(0)
label += ' chi2 DoF '
data += ' {0:14.1f} {1:d} '.format(fit['chibest'], self.DoF)
for line in ['SII', 'Ha', 'OIII', 'Hb', 'Hg', 'OII']:
label += ' {0} {0}_err'.format(line)
if line in fit['line_flux']:
flux = fit['line_flux'][line][0]
err = fit['line_flux'][line][1]
data += ' {0:10.3e} {1:10.3e}'.format(flux, err)
fp = open('{0}_{1:05d}.zfit.dat'.format(self.group_name, self.id), 'w')
fp.write(label+'\n')
fp.write(data+'\n')
fp.close()
fp = open('{0}_{1:05d}.zfit.beams.dat'.format(self.group_name, self.id), 'w')
fp.write('# file filter origin_x origin_y size padx pady bg\n')
for ib, beam in enumerate(self.beams):
msg = '{0:40s} {1:s} {2:5d} {3:5d} {4:5d} {5:5d} {6:5d}'
data = msg.format(beam.grism.parent_file,
beam.grism.filter,
beam.direct.origin[0],
beam.direct.origin[1],
beam.direct.sh[0],
beam.direct.pad[1],
beam.direct.pad[0]
)
if self.fit_bg:
data += ' {0:8.4f}'.format(fit['coeffs_full'][ib])
else:
data += ' {0:8.4f}'.format(0.0)
fp.write(data + '\n')
fp.close()
# Save figures
plt_status = plt.rcParams['interactive']
# if not plt_status:
# plt.close(fig)
# plt.close(fig2)
return fit, fig, fig2, hdu2, hdu_full
[docs] def apply_trace_shift(self, set_to_zero=False):
"""
Set beam.yoffset back to zero
"""
indices = [[i] for i in range(self.N)]
if set_to_zero:
s0 = np.zeros(len(indices))
else:
s0 = [beam.beam.yoffset for beam in self.beams]
args = (self, indices, 0, False, False, True)
self.eval_trace_shift(s0, *args)
# Reset model profile for optimal extractions
for b in self.beams:
# b._parse_from_data()
if hasattr(b, 'has_sys_err'):
delattr(b, 'has_sys_err')
b._parse_from_data(**b._parse_params)
self._parse_beam_arrays()
[docs] def fit_trace_shift(self, split_groups=True, max_shift=5, tol=1.e-2,
verbose=True, lm=False, fit_with_psf=False,
reset=False):
"""TBD
"""
from scipy.optimize import leastsq, minimize
if split_groups:
indices = []
for g in self.PA:
for p in self.PA[g]:
indices.append(self.PA[g][p])
else:
indices = [[i] for i in range(self.N)]
s0 = np.zeros(len(indices))
bounds = np.array([[-max_shift, max_shift]]*len(indices))
args = (self, indices, 0, lm, verbose, fit_with_psf)
if reset:
shifts = np.zeros(len(indices))
out = None
elif lm:
out = leastsq(self.eval_trace_shift, s0, args=args, Dfun=None, full_output=0, col_deriv=0, ftol=1.49012e-08, xtol=1.49012e-08, gtol=0.0, maxfev=0, epsfcn=None, factor=100, diag=None)
shifts = out[0]
else:
out = minimize(self.eval_trace_shift, s0, bounds=bounds, args=args, method='Powell', tol=tol)
if out.x.shape == ():
shifts = [float(out.x)]
else:
shifts = out.x
# Apply to PSF if necessary
args = (self, indices, 0, lm, verbose, True)
self.eval_trace_shift(shifts, *args)
# Reset model profile for optimal extractions
for b in self.beams:
# b._parse_from_data()
b._parse_from_data(**b._parse_params)
# Needed for background modeling
if hasattr(b, 'xp'):
delattr(b, 'xp')
self._parse_beam_arrays()
self.initialize_masked_arrays()
return shifts, out
[docs] @staticmethod
def eval_trace_shift(shifts, self, indices, poly_order, lm, verbose, fit_with_psf):
"""TBD
"""
import scipy.ndimage as nd
for il, l in enumerate(indices):
for i in l:
beam = self.beams[i]
beam.beam.add_ytrace_offset(shifts[il])
if hasattr(self.beams[i].beam, 'psf') & fit_with_psf:
#beam.model = nd.shift(beam.modelf.reshape(beam.sh_beam), (shifts[il], 0))
# This is slow, so run with fit_with_psf=False if possible
beam.init_epsf(yoff=0, # shifts[il],
psf_params=beam.beam.psf_params)
beam.compute_model(use_psf=True)
m = beam.compute_model(in_place=False)
#beam.modelf = beam.model.flatten()
#beam.model = beam.modelf.reshape(beam.beam.sh_beam)
beam.flat_flam = beam.compute_model(in_place=False, is_cgs=True)
else:
# self.beams[i].beam.add_ytrace_offset(shifts[il])
# self.beams[i].compute_model(is_cgs=True)
beam.compute_model(use_psf=False)
self.flat_flam = np.hstack([b.beam.model.flatten() for b in self.beams])
self.poly_order = -1
self.init_poly_coeffs(poly_order=poly_order)
self.fit_bg = False
A = self.A_poly*1
ok_temp = np.sum(A, axis=1) != 0
out_coeffs = np.zeros(A.shape[0])
y = self.scif
out = np.linalg.lstsq(A.T, y, rcond=utils.LSTSQ_RCOND)
lstsq_coeff, residuals, rank, s = out
coeffs = lstsq_coeff
out_coeffs = np.zeros(A.shape[0])
out_coeffs[ok_temp] = coeffs
modelf = np.dot(out_coeffs, A)
if lm:
# L-M, return residuals
if verbose:
print('{0} [{1}]'.format(utils.NO_NEWLINE, ' '.join(['{0:5.2f}'.format(s) for s in shifts])))
return ((self.scif-modelf)*self.sivarf)[self.fit_mask]
chi2 = np.sum(((self.scif - modelf)**2*self.ivarf)[self.fit_mask])
if verbose:
print('{0} [{1}] {2:6.2f}'.format(utils.NO_NEWLINE, ' '.join(['{0:5.2f}'.format(s) for s in shifts]), chi2/self.DoF))
return chi2/self.DoF
[docs] def drizzle_grisms_and_PAs(self, size=10, fcontam=0, flambda=False, scale=1, pixfrac=0.5, kernel='square', usewcs=False, tfit=None, diff=True, grism_list=['G800L', 'G102', 'G141', 'GR150C', 'GR150R', 'F090W', 'F115W', 'F150W', 'F200W', 'F356W', 'F300M', 'F335M', 'F360M', 'F410M', 'F430M','F460M','F480M','F444W'], mask_segmentation=True, reset_model=True, make_figure=True, fig_args=dict(mask_segmentation=True, average_only=False, scale_size=1, cmap='viridis_r'), **kwargs):
"""Make figure showing spectra at different orients/grisms
TBD
"""
from matplotlib.ticker import MultipleLocator
#import pysynphot as S
if usewcs:
drizzle_function = drizzle_2d_spectrum_wcs
else:
drizzle_function = drizzle_2d_spectrum
if 'zfit' in kwargs:
tfit = kwargs['zfit']
NX = len(self.PA)
NY = 0
for g in self.PA:
NY = np.maximum(NY, len(self.PA[g]))
NY += 1
# keys = list(self.PA)
keys = []
for key in grism_list:
if key in self.PA:
keys.append(key)
if tfit is not None:
if 'coeffs_full' in tfit:
bg = tfit['coeffs_full'][:self.N]
z_cont = tfit['zbest']
else:
# fitting.GroupFitter
z_cont = tfit['z']
bg = []
for k in tfit['cfit']:
if k.startswith('bg '):
bg.append(tfit['cfit'][k][0])
bg = np.array(bg)
else:
# Fit background
try:
out = self.xfit_at_z(z=0, templates={}, fitter='lstsq',
poly_order=3, fit_background=True)
bg = out[-3][:self.N]
except:
bg = [0]*self.N
for ib, beam in enumerate(self.beams):
beam.bg = bg[ib]
prim = pyfits.PrimaryHDU()
h0 = prim.header
h0['ID'] = (self.id, 'Object ID')
h0['RA'] = (self.ra, 'Right ascension')
h0['DEC'] = (self.dec, 'Declination')
h0['ISFLAM'] = (flambda, 'Pixels in f-lam units')
h0['FCONTAM'] = (fcontam, 'Contamination parameter')
h0['NGRISM'] = (len(keys), 'Number of grisms')
all_hdus = []
for ig, g in enumerate(keys):
all_beams = []
hdus = []
pas = list(self.PA[g].keys())
pas.sort()
h0['GRISM{0:03d}'.format(ig+1)] = (g, 'Grism name')
h0['N'+g] = (len(pas), 'Number of PAs for grism '+g)
for ipa, pa in enumerate(pas):
h0[g+'{0:02d}'.format(ipa+1)] = (pa, 'PA')
beams = [self.beams[i] for i in self.PA[g][pa]]
all_beams.extend(beams)
#dlam = np.ceil(np.diff(beams[0].beam.lam)[0])*scale
dlam = GRISM_LIMITS[g][2]*scale
data = [beam.grism['SCI']-beam.contam-beam.bg
for beam in beams]
hdu = drizzle_function(beams, data=data,
wlimit=GRISM_LIMITS[g], dlam=dlam,
spatial_scale=scale, NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=0, ds9=None,
mask_segmentation=mask_segmentation)
hdu[0].header['RA'] = (self.ra, 'Right ascension')
hdu[0].header['DEC'] = (self.dec, 'Declination')
hdu[0].header['GRISM'] = (g, 'Grism')
hdu[0].header['PA'] = (pa, 'Dispersion PA')
hdu[0].header['ISFLAM'] = (flambda, 'Pixels in f-lam units')
hdu[0].header['CONF'] = (beams[0].beam.conf.conf_file,
'Configuration file')
hdu[0].header['DLAM0'] = (np.median(np.diff(beams[0].wave)),
'Native dispersion per pix')
# Contam
data = [beam.contam for beam in beams]
hdu_contam = drizzle_function(beams, data=data,
wlimit=GRISM_LIMITS[g], dlam=dlam,
spatial_scale=scale, NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=0, ds9=None,
mask_segmentation=mask_segmentation)
hdu_contam[1].header['EXTNAME'] = 'CONTAM'
hdu.append(hdu_contam[1])
# Continuum model
if tfit is not None:
m = tfit['cont1d']
for beam in beams:
beam.compute_model(spectrum_1d=[m.wave, m.flux],
is_cgs=True)
else:
if reset_model:
# simple flat spectrum
for beam in beams:
beam.compute_model()
data = []
for beam in beams:
if hasattr(beam.beam, 'pscale_array'):
data.append(beam.beam.model*beam.beam.pscale_array)
else:
data.append(beam.beam.model)
hdu_model = drizzle_function(beams, data=data,
wlimit=GRISM_LIMITS[g], dlam=dlam,
spatial_scale=scale, NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=0, ds9=None,
mask_segmentation=mask_segmentation)
hdu_model[1].header['EXTNAME'] = 'MODEL'
if tfit is not None:
hdu_model[1].header['CONTIN1D'] = (True, 'Model is fit continuum')
hdu_model[1].header['REDSHIFT'] = (z_cont, 'Redshift of the continuum spectrum')
else:
hdu_model[1].header['CONTIN1D'] = (False, 'Model is fit continuum')
hdu.append(hdu_model[1])
# Line kernel
if (not usewcs):
h = hdu[1].header
# header keywords scaled to um
toA = 1.e4
gau = utils.SpectrumTemplate(central_wave=h['CRVAL1']*toA,
fwhm=h['CD1_1']*toA)
if reset_model:
for beam in beams:
beam.compute_model(spectrum_1d=[gau.wave,
gau.flux],
is_cgs=True)
data = [beam.beam.model for beam in beams]
h_kern = drizzle_function(beams, data=data,
wlimit=GRISM_LIMITS[g],
dlam=dlam,
spatial_scale=scale, NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=0, fill_wht=True,
ds9=None,
mask_segmentation=mask_segmentation)
kern = h_kern[1].data[:, h['CRPIX1']-1-size:h['CRPIX1']-1+size]
hdu_kern = pyfits.ImageHDU(data=kern, header=h_kern[1].header, name='KERNEL')
hdu.append(hdu_kern)
else:
hdu['DSCI'].header['EXTNAME'] = 'KERNEL'
# Pull out zeroth extension
for k in hdu[0].header:
hdu[1].header[k] = hdu[0].header[k]
for e in hdu[1:]:
e.header['EXTVER'] = '{0},{1}'.format(g, pa)
hdus.append(hdu[1:])
# Stack of each grism
data = [beam.grism['SCI']-beam.contam-beam.bg
for beam in all_beams]
hdu = drizzle_function(all_beams, data=data,
wlimit=GRISM_LIMITS[g], dlam=dlam,
spatial_scale=scale, NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=fcontam, ds9=None,
mask_segmentation=mask_segmentation)
hdu[0].header['RA'] = (self.ra, 'Right ascension')
hdu[0].header['DEC'] = (self.dec, 'Declination')
hdu[0].header['GRISM'] = (g, 'Grism')
hdu[0].header['ISFLAM'] = (flambda, 'Pixels in f-lam units')
hdu[0].header['CONF'] = (beams[0].beam.conf.conf_file,
'Configuration file')
hdu[0].header['DLAM0'] = (np.median(np.diff(beams[0].wave)),
'Native dispersion per pix')
# Full continuum model
if tfit is not None:
if diff > 1:
m = tfit['line1d']
else:
m = tfit['cont1d']
for beam in all_beams:
beam.compute_model(spectrum_1d=[m.wave, m.flux],
is_cgs=True)
else:
if reset_model:
for beam in all_beams:
beam.compute_model()
#data = [beam.beam.model for beam in all_beams]
data = []
for beam in all_beams:
if hasattr(beam.beam, 'pscale_array'):
data.append(beam.beam.model*beam.beam.pscale_array)
else:
data.append(beam.beam.model)
hdu_model = drizzle_function(all_beams, data=data,
wlimit=GRISM_LIMITS[g], dlam=dlam,
spatial_scale=scale, NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=fcontam, ds9=None,
mask_segmentation=mask_segmentation)
hdu_model[1].header['EXTNAME'] = 'MODEL'
if tfit is not None:
hdu_model[1].header['CONTIN1D'] = (True, 'Model is fit continuum')
hdu_model[1].header['REDSHIFT'] = (z_cont, 'Redshift of the continuum spectrum')
else:
hdu_model[1].header['CONTIN1D'] = (False, 'Model is fit continuum')
hdu.append(hdu_model[1])
# Full kernel
h = hdu[1].header
#gau = S.GaussianSource(1.e-17, h['CRVAL1'], h['CD1_1']*1)
toA = 1.e4
#gau = S.GaussianSource(1., h['CRVAL1']*toA, h['CD1_1']*toA)
gau = utils.SpectrumTemplate(central_wave=h['CRVAL1']*toA, fwhm=h['CD1_1']*toA)
if reset_model:
for beam in all_beams:
beam.compute_model(spectrum_1d=[gau.wave, gau.flux],
is_cgs=True)
data = [beam.beam.model for beam in all_beams]
h_kern = drizzle_function(all_beams, data=data,
wlimit=GRISM_LIMITS[g], dlam=dlam,
spatial_scale=scale, NY=size,
pixfrac=pixfrac,
kernel=kernel,
convert_to_flambda=flambda,
fcontam=0, fill_wht=True, ds9=None,
mask_segmentation=mask_segmentation)
kern = h_kern[1].data[:, int(h['CRPIX1'])-1-size:int(h['CRPIX1'])-1+size]
hdu_kern = pyfits.ImageHDU(data=kern, header=h_kern[1].header, name='KERNEL')
hdu.append(hdu_kern)
# Pull out zeroth extension
for k in hdu[0].header:
hdu[1].header[k] = hdu[0].header[k]
for e in hdu[1:]:
e.header['EXTVER'] = '{0}'.format(g)
hdus.append(hdu[1:])
all_hdus.extend(hdus)
output_hdu = pyfits.HDUList([prim])
for hdu in all_hdus:
output_hdu.extend(hdu)
if make_figure:
fig = show_drizzle_HDU(output_hdu, diff=diff, **fig_args)
return output_hdu, fig
else:
return output_hdu # all_hdus
[docs] def flag_with_drizzled(self, hdul, sigma=4, update=True, interp='nearest', verbose=True):
"""
Update `MultiBeam` masks based on the blotted drizzled combined image
[in progress ... xxx]
Parameters
----------
hdul : `~astropy.io.fits.HDUList`
FITS HDU list output from `drizzle_grisms_and_PAs` or read from a
`stack.fits` file.
sigma : float
Residual threshold to flag.
update : bool
Update the mask.
interp : str
Interpolation method for `~drizzlepac.ablot`.
Returns
-------
Updates the individual `fit_mask` attributes of the individual beams
if `update==True`.
"""
try:
from drizzle.doblot import doblot
blotter = doblot
except:
from drizzlepac import ablot
blotter = ablot.do_blot
# Read the drizzled arrays
Ng = hdul[0].header['NGRISM']
ref_wcs = {}
ref_data = {}
flag_grism = {}
for i in range(Ng):
g = hdul[0].header['GRISM{0:03d}'.format(i+1)]
ref_wcs[g] = pywcs.WCS(hdul['SCI', g].header)
ref_wcs[g].pscale = utils.get_wcs_pscale(ref_wcs[g])
ref_data[g] = hdul['SCI', g].data
flag_grism[g] = hdul[0].header['N{0}'.format(g)] > 1
# Do the masking
for i, beam in enumerate(self.beams):
g = beam.grism.filter
if not flag_grism[g]:
continue
beam_header, flt_wcs = beam.full_2d_wcs()
blotted = blotter(ref_data[g], ref_wcs[g],
flt_wcs, 1,
coeffs=True, interp=interp, sinscl=1.0,
stepsize=10, wcsmap=None)
resid = (beam.grism['SCI'] - beam.contam - blotted)
resid *= np.sqrt(beam.ivar)
blot_mask = (blotted != 0) & (np.abs(resid) < sigma)
if verbose:
print('Beam {0:>3d}: {1:>4d} new masked pixels'.format(i, beam.fit_mask.sum() - (beam.fit_mask & blot_mask.flatten()).sum()))
if update:
beam.fit_mask &= blot_mask.flatten()
if update:
self._parse_beams()
self.initialize_masked_arrays()
[docs] def oned_spectrum(self, tfit=None, get_contam=True, get_background=False, masked_model=None, **kwargs):
"""Compute full 1D spectrum with optional best-fit model
Parameters
----------
bin : float / int
Bin factor relative to the size of the native spectral bins of a
given grism.
tfit : dict
Output of `~grizli.fitting.mb.template_at_z`.
Returns
-------
sp : dict
Dictionary of the extracted 1D spectra. Keys are the grism
names and the values are `~astropy.table.Table` objects.
"""
import astropy.units as u
# "Flat" spectrum to perform flux calibration
if self.Nphot > 0:
flat_data = self.flat_flam[self.fit_mask[:-self.Nphotbands]]
else:
flat_data = self.flat_flam[self.fit_mask]
sp_flat = self.optimal_extract(flat_data, **kwargs)
# Best-fit line and continuum models, with background fit
if tfit is not None:
bg_model = self.get_flat_background(tfit['coeffs'],
apply_mask=True)
line_model = self.get_flat_model([tfit['line1d'].wave,
tfit['line1d'].flux])
cont_model = self.get_flat_model([tfit['line1d'].wave,
tfit['cont1d'].flux])
sp_line = self.optimal_extract(line_model, **kwargs)
sp_cont = self.optimal_extract(cont_model, **kwargs)
elif masked_model is not None:
bg_model = 0.
sp_model = self.optimal_extract(masked_model, **kwargs)
else:
bg_model = 0.
# Optimal spectral extraction
sp = self.optimal_extract(self.scif_mask[:self.Nspec]-bg_model, **kwargs)
if get_contam:
spc = self.optimal_extract(self.contamf_mask[:self.Nspec],
**kwargs)
if (tfit is not None) & (get_background):
bgm = self.get_flat_background(tfit['coeffs'], apply_mask=True)
sp_bg = self.optimal_extract(bgm[:self.Nspec], **kwargs)
else:
sp_bg = None
# Loop through grisms, change units and add fit columns
# NB: setting units to "count / s" to comply with FITS standard,
# where count / s = electron / s
for k in sp:
sp[k]['flat'] = sp_flat[k]['flux']
flat_unit = (u.count / u.s) / (u.erg / u.s / u.cm**2 / u.AA)
sp[k]['flat'].unit = flat_unit
sp[k]['flux'].unit = u.count / u.s
sp[k]['err'].unit = u.count / u.s
if get_contam:
sp[k]['contam'] = spc[k]['flux']
sp[k]['contam'].unit = u.count / u.s
if tfit is not None:
sp[k]['line'] = sp_line[k]['flux']
sp[k]['line'].unit = u.count / u.s
sp[k]['cont'] = sp_cont[k]['flux']
sp[k]['cont'].unit = u.count / u.s
if masked_model is not None:
sp[k]['model'] = sp_model[k]['flux']
sp[k]['model'].unit = u.count / u.s
if sp_bg is not None:
sp[k]['background'] = sp_bg[k]['flux']
sp[k]['background'].unit = u.count / u.s
sp[k].meta['GRISM'] = (k, 'Grism name')
# Metadata
exptime = count = 0
for pa in self.PA[k]:
for i in self.PA[k][pa]:
exptime += self.beams[i].grism.header['EXPTIME']
count += 1
parent = (self.beams[i].grism.parent_file, 'Parent file')
sp[k].meta['FILE{0:04d}'.format(count)] = parent
sp[k].meta['NEXP'] = (count, 'Number of exposures')
sp[k].meta['EXPTIME'] = (exptime, 'Total exposure time')
sp[k].meta['NPA'] = (len(self.PA[k]), 'Number of PAs')
# PSCALE
if hasattr(self, 'pscale'):
if (self.pscale is not None):
pscale = self.compute_scale_array(self.pscale,
sp[k]['wave'])
sp[k]['pscale'] = pscale
sp[k].meta['PSCALEN'] = (len(self.pscale)-1,
'PSCALE order')
for i, p in enumerate(self.pscale):
sp[k].meta['PSCALE{0}'.format(i)] = (p,
'PSCALE parameter {0}'.format(i))
return sp
[docs] def oned_spectrum_to_hdu(self, sp=None, outputfile=None, units=None, **kwargs):
"""Generate 1D spectra fits HDUList
Parameters
----------
sp : optional, dict
Output of `~grizli.multifit.MultiBeam.oned_spectrum`. If None,
then run that function with `**kwargs`.
outputfile : None, str
If a string supplied, then write the `~astropy.io.fits.HDUList` to
a file.
Returns
-------
hdul : `~astropy.io.fits.HDUList`
FITS version of the 1D spectrum tables.
"""
from astropy.io.fits.convenience import table_to_hdu
# Generate the spectrum if necessary
if sp is None:
sp = self.oned_spectrum(**kwargs)
# Metadata in PrimaryHDU
prim = pyfits.PrimaryHDU()
prim.header['ID'] = (self.id, 'Object ID')
prim.header['RA'] = (self.ra, 'Right Ascension')
prim.header['DEC'] = (self.dec, 'Declination')
prim.header['TARGET'] = (self.group_name, 'Target Name')
prim.header['MW_EBV'] = (self.MW_EBV, 'Galactic extinction E(B-V)')
for g in ['G102', 'G141', 'G800L']:
if g in sp:
prim.header['N_{0}'.format(g)] = sp[g].meta['NEXP']
prim.header['T_{0}'.format(g)] = sp[g].meta['EXPTIME']
prim.header['PA_{0}'.format(g)] = sp[g].meta['NPA']
else:
prim.header['N_{0}'.format(g)] = (0, 'Number of exposures')
prim.header['T_{0}'.format(g)] = (0, 'Total exposure time')
prim.header['PA_{0}'.format(g)] = (0, 'Number of PAs')
for i, k in enumerate(sp):
prim.header['GRISM{0:03d}'.format(i+1)] = (k, 'Grism name')
# Generate HDUList
hdul = [prim]
for k in sp:
hdu = table_to_hdu(sp[k])
hdu.header['EXTNAME'] = k
hdul.append(hdu)
# Outputs
hdul = pyfits.HDUList(hdul)
if outputfile is None:
return hdul
else:
hdul.writeto(outputfile, overwrite=True)
return hdul
[docs] def make_simple_hdulist(self):
"""
Make a`~astropy.io.fits.HDUList` object with just a simple
PrimaryHDU
"""
p = pyfits.PrimaryHDU()
p.header['ID'] = (self.id, 'Object ID')
p.header['RA'] = (self.ra, 'R.A.')
p.header['DEC'] = (self.dec, 'Decl.')
p.header['NINPUT'] = (len(self.beams), 'Number of drizzled beams')
p.header['HASLINES'] = ('', 'Lines in this file')
for i, beam in enumerate(self.beams):
p.header['FILE{0:04d}'.format(i+1)] = (beam.grism.parent_file,
'Parent filename')
p.header['GRIS{0:04d}'.format(i+1)] = (beam.grism.filter,
'Beam grism element')
p.header['PA{0:04d}'.format(i+1)] = (beam.get_dispersion_PA(),
'PA of dispersion axis')
return pyfits.HDUList(p)
[docs] def check_for_bad_PAs(self, poly_order=1, chi2_threshold=1.5, fit_background=True, reinit=True):
"""
"""
wave = np.linspace(2000, 2.5e4, 100)
poly_templates = utils.polynomial_templates(wave, order=poly_order)
fit_log = OrderedDict()
keep_dict = {}
has_bad = False
keep_beams = []
for g in self.PA:
fit_log[g] = OrderedDict()
keep_dict[g] = []
for pa in self.PA[g]:
beams = [self.beams[i] for i in self.PA[g][pa]]
mb_i = MultiBeam(beams, fcontam=self.fcontam,
sys_err=self.sys_err, min_sens=self.min_sens,
min_mask=self.min_mask,
mask_resid=self.mask_resid,
MW_EBV=self.MW_EBV)
try:
chi2, _, _, _ = mb_i.xfit_at_z(z=0,
templates=poly_templates,
fit_background=fit_background)
except:
chi2 = 1e30
if False:
p_i = mb_i.template_at_z(z=0, templates=poly_templates, fit_background=fit_background, fitter='lstsq', fwhm=1400, get_uncertainties=2)
fit_log[g][pa] = {'chi2': chi2, 'DoF': mb_i.DoF,
'chi_nu': chi2/np.maximum(mb_i.DoF, 1)}
min_chinu = 1e30
for pa in self.PA[g]:
min_chinu = np.minimum(min_chinu, fit_log[g][pa]['chi_nu'])
fit_log[g]['min_chinu'] = min_chinu
for pa in self.PA[g]:
fit_log[g][pa]['chinu_ratio'] = fit_log[g][pa]['chi_nu']/min_chinu
if fit_log[g][pa]['chinu_ratio'] < chi2_threshold:
keep_dict[g].append(pa)
keep_beams.extend([self.beams[i] for i in self.PA[g][pa]])
else:
has_bad = True
if reinit:
self.beams = keep_beams
self._parse_beams(psf=self.psf_param_dict is not None)
return fit_log, keep_dict, has_bad
[docs]def get_redshift_fit_defaults():
"""TBD
"""
pzfit_def = dict(zr=[0.5, 1.6], dz=[0.005, 0.0004], fwhm=0,
poly_order=0, fit_background=True,
delta_chi2_threshold=0.004, fitter='nnls',
prior=None, templates={}, figsize=[8, 5],
fsps_templates=False)
pspec2_def = dict(dlam=0, spatial_scale=1, NY=20, figsize=[8, 3.5])
pline_def = dict(size=20, pixscale=0.1, pixfrac=0.2, kernel='square',
wcs=None)
return pzfit_def, pspec2_def, pline_def
[docs]def drizzle_2d_spectrum(beams, data=None, wlimit=[1.05, 1.75], dlam=50,
spatial_scale=1, NY=10, pixfrac=0.6, kernel='square',
convert_to_flambda=True, fcontam=0.2, fill_wht=False,
ds9=None, mask_segmentation=True):
"""Drizzle 2D spectrum from a list of beams
Parameters
----------
beams : list of `~.model.BeamCutout` objects
data : None or list
optionally, drizzle data specified in this list rather than the
contamination-subtracted arrays from each beam.
wlimit : [float, float]
Limits on the wavelength array to drizzle ([wlim, wmax])
dlam : float
Delta wavelength per pixel
spatial_scale : float
Relative scaling of the spatial axis (1 = native pixels)
NY : int
Size of the cutout in the spatial dimension, in output pixels
pixfrac : float
Drizzle PIXFRAC (for `kernel` = 'point')
kernel : str, ('square' or 'point')
Drizzle kernel to use
convert_to_flambda : bool, float
Convert the 2D spectrum to physical units using the sensitivity curves
and if float provided, scale the flux densities by that value
fcontam: float
Factor by which to scale the contamination arrays and add to the
pixel variances.
fill_wht: bool
Fill `wht==0` pixels of the beam weights with the median nonzero
value.
ds9: `~grizli.ds9.DS9`
Show intermediate steps of the drizzling
Returns
-------
hdu : `~astropy.io.fits.HDUList`
FITS HDUList with the drizzled 2D spectrum and weight arrays
"""
from astropy import log
# try:
# import drizzle
# if drizzle.__version__ != '1.12.99':
# # Not the fork that works for all input/output arrays
# raise(ImportError)
#
# #print('drizzle!!')
# from drizzle.dodrizzle import dodrizzle
# drizzler = dodrizzle
# dfillval = '0'
# except:
from drizzlepac import adrizzle
adrizzle.log.setLevel('ERROR')
drizzler = adrizzle.do_driz
dfillval = 0
log.setLevel('ERROR')
# log.disable_warnings_logging()
NX = int(np.round(np.diff(wlimit)[0]*1.e4/dlam)) // 2
center = np.mean(wlimit[:2])*1.e4
out_header, output_wcs = utils.full_spectrum_wcsheader(center_wave=center,
dlam=dlam, NX=NX,
spatial_scale=spatial_scale, NY=NY)
sh = (out_header['NAXIS2'], out_header['NAXIS1'])
if not hasattr(output_wcs, '_naxis1'):
output_wcs._naxis2, output_wcs._naxis1 = sh
outsci = np.zeros(sh, dtype=np.float32)
outwht = np.zeros(sh, dtype=np.float32)
outctx = np.zeros(sh, dtype=np.int32)
outvar = np.zeros(sh, dtype=np.float32)
outwv = np.zeros(sh, dtype=np.float32)
outcv = np.zeros(sh, dtype=np.int32)
if data is None:
data = []
for i, beam in enumerate(beams):
# Contamination-subtracted
beam_data = beam.grism.data['SCI'] - beam.contam
data.append(beam_data)
for i, beam in enumerate(beams):
# Get specific WCS for each beam
beam_header, beam_wcs = beam.full_2d_wcs()
if not hasattr(beam_wcs, 'pixel_shape'):
beam_wcs.pixel_shape = beam_wcs._naxis1, beam_wcs._naxis2
if not hasattr(beam_wcs, '_naxis1'):
beam_wcs._naxis1, beam_wcs._naxis2 = beam_wcs._naxis
# Downweight contamination
# wht = 1/beam.ivar + (fcontam*beam.contam)**2
# wht = np.asarray(1/wht,dtype=np.float32)
# wht[~np.isfinite(wht)] = 0.
contam_weight = np.exp(-(fcontam*np.abs(beam.contam)*np.sqrt(beam.ivar)))
wht = beam.ivar*contam_weight
wht[~np.isfinite(wht)] = 0.
contam_weight[beam.ivar == 0] = 0
if fill_wht:
wht_mask = wht == 0
med_wht = np.median(wht[~wht_mask])
wht[wht_mask] = med_wht
#print('xx Fill weight: {0}'.format(med_wht))
data_i = data[i]*1.
scl = 1.
if convert_to_flambda:
#data_i *= convert_to_flambda/beam.beam.sensitivity
#wht *= (beam.beam.sensitivity/convert_to_flambda)**2
scl = convert_to_flambda # /1.e-17
scl *= 1./beam.flat_flam.reshape(beam.beam.sh_beam).sum(axis=0)
#scl = convert_to_flambda/beam.beam.sensitivity
data_i *= scl
wht *= (1/scl)**2
#contam_weight *= scl
wht[~np.isfinite(data_i+scl)] = 0
contam_weight[~np.isfinite(data_i+scl)] = 0
data_i[~np.isfinite(data_i+scl)] = 0
# Go drizzle
# Contamination-cleaned
drizzler(data_i, beam_wcs, wht, output_wcs,
outsci, outwht, outctx, 1., 'cps', 1,
wcslin_pscale=1., uniqid=1,
pixfrac=pixfrac, kernel=kernel, fillval=dfillval)
# For variance
drizzler(contam_weight, beam_wcs, wht, output_wcs,
outvar, outwv, outcv, 1., 'cps', 1,
wcslin_pscale=1., uniqid=1,
pixfrac=pixfrac, kernel=kernel, fillval=dfillval)
if ds9 is not None:
ds9.view(outsci/output_wcs.pscale**2, header=out_header)
# if False:
# # Plot the spectra for testing
# w, f, e = beam.beam.trace_extract(data_i, ivar=wht, r=3)
# clip = (f/e > 0.5)
# clip &= (e < 2*np.median(e[clip]))
# plt.errorbar(w[clip], f[clip], e[clip], marker='.', color='k', alpha=0.5, ecolor='0.8', linestyle='None')
# dw = np.median(np.diff(w))
# Correct for drizzle scaling
area_ratio = 1./output_wcs.pscale**2
# Preserve flux (has to preserve aperture flux along spatial axis but
# average in spectral axis).
#area_ratio *= spatial_scale
# preserve flux density
flux_density_scale = spatial_scale**2
# science
outsci *= area_ratio*flux_density_scale
# variance
outvar *= area_ratio/outwv*flux_density_scale**2
outwht = 1/outvar
outwht[(outvar == 0) | (~np.isfinite(outwht))] = 0
# if True:
# # Plot for testing....
# yp, xp = np.indices(outsci.shape)
# mask = np.abs(yp-NY) <= 3/spatial_scale
# fl = (outsci*mask).sum(axis=0)
# flv = (1/outwht*mask).sum(axis=0)
#
# wi = grizli.stack.StackedSpectrum.get_wavelength_from_header(out_header)
#
# plt.errorbar(wi[:-1], fl[1:], np.sqrt(flv)[1:], alpha=0.8) #*area_ratio)
# return outwht, outsci, outvar, outwv, output_wcs.pscale
p = pyfits.PrimaryHDU()
p.header['ID'] = (beams[0].id, 'Object ID')
p.header['WMIN'] = (wlimit[0], 'Minimum wavelength')
p.header['WMAX'] = (wlimit[1], 'Maximum wavelength')
p.header['DLAM'] = (dlam, 'Delta wavelength')
p.header['SSCALE'] = (spatial_scale, 'Spatial scale factor w.r.t native')
p.header['FCONTAM'] = (fcontam, 'Contamination weight')
p.header['PIXFRAC'] = (pixfrac, 'Drizzle PIXFRAC')
p.header['DRIZKRNL'] = (kernel, 'Drizzle kernel')
p.header['BEAM'] = (beams[0].beam.beam, 'Grism order')
p.header['NINPUT'] = (len(beams), 'Number of drizzled beams')
exptime = 0.
for i, beam in enumerate(beams):
p.header['FILE{0:04d}'.format(i+1)] = (beam.grism.parent_file,
'Parent filename')
p.header['GRIS{0:04d}'.format(i+1)] = (beam.grism.filter,
'Beam grism element')
p.header['PA{0:04d}'.format(i+1)] = (beam.get_dispersion_PA(),
'PA of dispersion axis')
exptime += beam.grism.exptime
p.header['EXPTIME'] = (exptime, 'Total exposure time [s]')
h = out_header.copy()
grism_sci = pyfits.ImageHDU(data=outsci, header=h, name='SCI')
grism_wht = pyfits.ImageHDU(data=outwht, header=h, name='WHT')
hdul = pyfits.HDUList([p, grism_sci, grism_wht])
return hdul
[docs]def drizzle_to_wavelength(beams, wcs=None, ra=0., dec=0., wave=1.e4, size=5, pixscale=0.1, pixfrac=0.6, kernel='square', theta=0., direct_extension='REF', fcontam=0.2, custom_key=None, ds9=None):
"""Drizzle a cutout at a specific wavelength from a list of `~grizli.model.BeamCutout` objects
Parameters
----------
beams : list of `~.model.BeamCutout` objects.
wcs : `~astropy.wcs.WCS` or None
Pre-determined WCS. If not specified, generate one based on ``ra``,
``dec``, ``pixscale`` and ``pixscale``.
ra, dec, wave : float
Sky coordinates and central wavelength
size : float
Size of the output thumbnail, in arcsec
pixscale : float
Pixel scale of the output thumbnail, in arcsec
pixfrac : float
Drizzle PIXFRAC (for ``kernel`` = 'point')
kernel : str, ('square' or 'point')
Drizzle kernel to use
theta : float
Position angle of output WCS
direct_extension : str, ('SCI' or 'REF')
Extension of ``self.direct.data`` do drizzle for the thumbnail
fcontam: float
Factor by which to scale the contamination arrays and add to the
pixel variances.
custom_key : str
Key of `beam.grism.data` dictionary to use as the science array for the
drizzled output
ds9 : `~grizli.ds9.DS9`, optional
Display each step of the drizzling to an open DS9 window
Returns
-------
hdu : `~astropy.io.fits.HDUList`
FITS HDUList with the drizzled thumbnail, line and continuum
cutouts.
"""
# try:
# import drizzle
# if drizzle.__version__ != '1.12.99':
# # Not the fork that works for all input/output arrays
# raise(ImportError)
#
# #print('drizzle!!')
# from drizzle.dodrizzle import dodrizzle
# drizzler = dodrizzle
# dfillval = '0'
# except:
from drizzlepac import adrizzle
adrizzle.log.setLevel('ERROR')
drizzler = adrizzle.do_driz
dfillval = 0
# Nothing to do
if len(beams) == 0:
return False
# Get output header and WCS
if wcs is None:
header, output_wcs = utils.make_wcsheader(ra=ra, dec=dec,
size=size,
theta=theta,
pixscale=pixscale,
get_hdu=False)
else:
output_wcs = wcs.copy()
if not hasattr(output_wcs, 'pscale'):
output_wcs.pscale = utils.get_wcs_pscale(output_wcs)
header = utils.to_header(output_wcs, relax=True)
if not hasattr(output_wcs, '_naxis1'):
output_wcs._naxis1, output_wcs._naxis2 = output_wcs._naxis
# Initialize data
sh = (header['NAXIS2'], header['NAXIS1'])
outsci = np.zeros(sh, dtype=np.float32)
outwht = np.zeros(sh, dtype=np.float32)
outctx = np.zeros(sh, dtype=np.int32)
coutsci = np.zeros(sh, dtype=np.float32)
coutwht = np.zeros(sh, dtype=np.float32)
coutctx = np.zeros(sh, dtype=np.int32)
xoutsci = np.zeros(sh, dtype=np.float32)
xoutwht = np.zeros(sh, dtype=np.float32)
xoutctx = np.zeros(sh, dtype=np.int32)
#direct_filters = np.unique([b.direct.filter for b in self.beams])
all_direct_filters = []
for beam in beams:
if direct_extension == 'REF':
if beam.direct['REF'] is None:
filt_i = beam.direct.ref_filter
direct_extension = 'SCI'
else:
filt_i = beam.direct.filter
all_direct_filters.append(filt_i)
direct_filters = np.unique(all_direct_filters)
doutsci, doutwht, doutctx = {}, {}, {}
for f in direct_filters:
doutsci[f] = np.zeros(sh, dtype=np.float32)
doutwht[f] = np.zeros(sh, dtype=np.float32)
doutctx[f] = np.zeros(sh, dtype=np.int32)
# doutsci = np.zeros(sh, dtype=np.float32)
# doutwht = np.zeros(sh, dtype=np.float32)
# doutctx = np.zeros(sh, dtype=np.int32)
# Loop through beams and run drizzle
for i, beam in enumerate(beams):
# Get specific wavelength WCS for each beam
beam_header, beam_wcs = beam.get_wavelength_wcs(wave)
if not hasattr(beam_wcs, 'pixel_shape'):
beam_wcs.pixel_shape = beam_wcs._naxis1, beam_wcs._naxis2
if not hasattr(beam_wcs, '_naxis1'):
beam_wcs._naxis1, beam_wcs._naxis2 = beam_wcs._naxis
# Make sure CRPIX set correctly for the SIP header
for j in [0, 1]:
# if beam_wcs.sip is not None:
# beam_wcs.sip.crpix[j] = beam_wcs.wcs.crpix[j]
if beam.direct.wcs.sip is not None:
beam.direct.wcs.sip.crpix[j] = beam.direct.wcs.wcs.crpix[j]
for wcs_ext in [beam_wcs.sip]:
if wcs_ext is not None:
wcs_ext.crpix[j] = beam_wcs.wcs.crpix[j]
# ACS requires additional wcs attributes
ACS_CRPIX = [4096/2, 2048/2]
dx_crpix = beam_wcs.wcs.crpix[0] - ACS_CRPIX[0]
dy_crpix = beam_wcs.wcs.crpix[1] - ACS_CRPIX[1]
for wcs_ext in [beam_wcs.cpdis1, beam_wcs.cpdis2,
beam_wcs.det2im1, beam_wcs.det2im2]:
if wcs_ext is not None:
wcs_ext.crval[0] += dx_crpix
wcs_ext.crval[1] += dy_crpix
if custom_key is None:
beam_data = beam.grism.data['SCI'] - beam.contam
if hasattr(beam, 'background'):
beam_data -= beam.background
if hasattr(beam, 'extra_lines'):
beam_data -= beam.extra_lines
beam_continuum = beam.beam.model*1
if hasattr(beam.beam, 'pscale_array'):
beam_continuum *= beam.beam.pscale_array
else:
beam_data = beam.grism.data[custom_key]*1
beam_continuum = beam_data * 0
# Downweight contamination
if fcontam > 0:
# wht = 1/beam.ivar + (fcontam*beam.contam)**2
# wht = np.asarray(1/wht,dtype=np.float32)
# wht[~np.isfinite(wht)] = 0.
contam_weight = np.exp(-(fcontam*np.abs(beam.contam)*np.sqrt(beam.ivar)))
wht = beam.ivar*contam_weight
wht[~np.isfinite(wht)] = 0.
else:
wht = beam.ivar*1
# Convert to f_lambda integrated line fluxes:
# (Inverse of the aXe sensitivity) x (size of pixel in \AA)
sens = np.interp(wave, beam.beam.lam, beam.beam.sensitivity,
left=0, right=0)
dlam = np.interp(wave, beam.beam.lam[1:], np.diff(beam.beam.lam))
# 1e-17 erg/s/cm2 #, scaling closer to e-/s
sens *= 1.e-17
sens *= 1./dlam
if sens == 0:
continue
else:
wht *= sens**2
beam_data /= sens
beam_continuum /= sens
# Go drizzle
# Contamination-cleaned
drizzler(beam_data, beam_wcs, wht, output_wcs,
outsci, outwht, outctx, 1., 'cps', 1,
wcslin_pscale=beam.grism.wcs.pscale, uniqid=1,
pixfrac=pixfrac, kernel=kernel, fillval=dfillval,
wcsmap=utils.WCSMapAll,
)
# Continuum
drizzler(beam_continuum, beam_wcs, wht, output_wcs,
coutsci, coutwht, coutctx, 1., 'cps', 1,
wcslin_pscale=beam.grism.wcs.pscale, uniqid=1,
pixfrac=pixfrac, kernel=kernel, fillval=dfillval,
wcsmap=utils.WCSMapAll,
)
# Contamination
drizzler(beam.contam, beam_wcs, wht, output_wcs,
xoutsci, xoutwht, xoutctx, 1., 'cps', 1,
wcslin_pscale=beam.grism.wcs.pscale, uniqid=1,
pixfrac=pixfrac, kernel=kernel, fillval=dfillval,
wcsmap=utils.WCSMapAll,
)
# Direct thumbnail
filt_i = all_direct_filters[i]
if direct_extension == 'REF':
thumb = beam.direct['REF']
thumb_wht = np.asarray((thumb != 0)*1,dtype=np.float32)
else:
thumb = beam.direct[direct_extension] # /beam.direct.photflam
thumb_wht = 1./(beam.direct.data['ERR']/beam.direct.photflam)**2
thumb_wht[~np.isfinite(thumb_wht)] = 0
if not hasattr(beam.direct.wcs, 'pixel_shape'):
beam.direct.wcs.pixel_shape = (beam.direct.wcs._naxis1,
beam.direct.wcs._naxis2)
if not hasattr(beam.direct.wcs, '_naxis1'):
beam.direct.wcs._naxis1, beam.direct.wcs._naxis2 = beam.direct.wcs._naxis
drizzler(thumb, beam.direct.wcs, thumb_wht, output_wcs,
doutsci[filt_i], doutwht[filt_i], doutctx[filt_i],
1., 'cps', 1,
wcslin_pscale=beam.direct.wcs.pscale, uniqid=1,
pixfrac=pixfrac, kernel=kernel, fillval=dfillval,
wcsmap=utils.WCSMapAll,
)
# Show in ds9
if ds9 is not None:
ds9.view((outsci-coutsci), header=header)
# Scaling of drizzled outputs
outwht *= (beams[0].grism.wcs.pscale/output_wcs.pscale)**4
coutwht *= (beams[0].grism.wcs.pscale/output_wcs.pscale)**4
xoutwht *= (beams[0].grism.wcs.pscale/output_wcs.pscale)**4
for filt_i in all_direct_filters:
doutwht[filt_i] *= (beams[0].direct.wcs.pscale/output_wcs.pscale)**4
# Make output FITS products
p = pyfits.PrimaryHDU()
p.header['ID'] = (beams[0].id, 'Object ID')
p.header['RA'] = (ra, 'Central R.A.')
p.header['DEC'] = (dec, 'Central Decl.')
p.header['PIXFRAC'] = (pixfrac, 'Drizzle PIXFRAC')
p.header['DRIZKRNL'] = (kernel, 'Drizzle kernel')
p.header['NINPUT'] = (len(beams), 'Number of drizzled beams')
for i, beam in enumerate(beams):
p.header['FILE{0:04d}'.format(i+1)] = (beam.grism.parent_file,
'Parent filename')
p.header['GRIS{0:04d}'.format(i+1)] = (beam.grism.filter,
'Beam grism element')
p.header['PA{0:04d}'.format(i+1)] = (beam.get_dispersion_PA(),
'PA of dispersion axis')
h = header.copy()
h['ID'] = (beam.id, 'Object ID')
h['PIXFRAC'] = (pixfrac, 'Drizzle PIXFRAC')
h['DRIZKRNL'] = (kernel, 'Drizzle kernel')
p.header['NDFILT'] = len(direct_filters), 'Number of direct image filters'
for i, filt_i in enumerate(direct_filters):
p.header['DFILT{0:02d}'.format(i+1)] = filt_i
p.header['NFILT{0:02d}'.format(i+1)] = all_direct_filters.count(filt_i), 'Number of beams with this direct filter'
HDUL = [p]
for i, filt_i in enumerate(direct_filters):
h['FILTER'] = (filt_i, 'Direct image filter')
thumb_sci = pyfits.ImageHDU(data=doutsci[filt_i], header=h,
name='DSCI')
thumb_wht = pyfits.ImageHDU(data=doutwht[filt_i], header=h,
name='DWHT')
thumb_sci.header['EXTVER'] = filt_i
thumb_wht.header['EXTVER'] = filt_i
HDUL += [thumb_sci, thumb_wht]
#thumb_seg = pyfits.ImageHDU(data=seg_slice, header=h, name='DSEG')
h['FILTER'] = (beam.grism.filter, 'Grism filter')
h['WAVELEN'] = (wave, 'Central wavelength')
grism_sci = pyfits.ImageHDU(data=outsci-coutsci, header=h, name='LINE')
grism_cont = pyfits.ImageHDU(data=coutsci, header=h, name='CONTINUUM')
grism_contam = pyfits.ImageHDU(data=xoutsci, header=h, name='CONTAM')
grism_wht = pyfits.ImageHDU(data=outwht, header=h, name='LINEWHT')
#HDUL = [p, thumb_sci, thumb_wht, grism_sci, grism_cont, grism_contam, grism_wht]
HDUL += [grism_sci, grism_cont, grism_contam, grism_wht]
return pyfits.HDUList(HDUL)
[docs]def show_drizzle_HDU(hdu, diff=True, mask_segmentation=True, average_only=False, scale_size=1, cmap='viridis_r', show_labels=True, width_ratio=0.2, **kwargs):
"""Make a figure from the multiple extensions in the drizzled grism file.
Parameters
----------
hdu : `~astropy.io.fits.HDUList`
HDU list output by `drizzle_grisms_and_PAs`.
diff : bool
If True, then plot the stacked spectrum minus the model.
Returns
-------
fig : `~matplotlib.figure.Figure`
The figure.
"""
from collections import OrderedDict
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import MultipleLocator
h0 = hdu[0].header
NX = h0['NGRISM']
NY = 0
grisms = OrderedDict()
for ig in range(NX):
g = h0['GRISM{0:03d}'.format(ig+1)]
NY = np.maximum(NY, h0['N'+g])
grisms[g] = h0['N'+g]
NY += 1
widths = []
for i in range(NX):
widths.extend([width_ratio, 1])
if average_only:
NY = 1
fig = plt.figure(figsize=(NX*scale_size/width_ratio,
NY*scale_size+0.33))
gs = GridSpec(NY, NX*2, width_ratios=widths)
else:
fig = plt.figure(figsize=(NX*scale_size/width_ratio, 1*NY*scale_size))
gs = GridSpec(NY, NX*2, height_ratios=[1]*NY, width_ratios=widths)
for ig, g in enumerate(grisms):
sci_i = hdu['SCI', g]
wht_i = hdu['WHT', g]
model_i = hdu['MODEL', g]
kern_i = hdu['KERNEL', g]
h_i = sci_i.header
clip = wht_i.data > 0
if clip.sum() == 0:
clip = np.isfinite(wht_i.data)
avg_rms = 1/np.median(np.sqrt(wht_i.data[clip]))
vmax = np.maximum(1.1*np.percentile(sci_i.data[clip], 98),
5*avg_rms)
vmax_kern = 1.1*np.percentile(kern_i.data, 99.5)
# Kernel
ax = fig.add_subplot(gs[NY-1, ig*2+0])
sh = kern_i.data.shape
extent = [0, sh[1], 0, sh[0]]
ax.imshow(kern_i.data, origin='lower', interpolation='Nearest',
vmin=-0.1*vmax_kern, vmax=vmax_kern, cmap=cmap,
extent=extent, aspect='auto')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.xaxis.set_tick_params(length=0)
ax.yaxis.set_tick_params(length=0)
# Spectrum
sh = sci_i.data.shape
extent = [h_i['WMIN'], h_i['WMAX'], 0, sh[0]]
ax = fig.add_subplot(gs[NY-1, ig*2+1])
if diff:
#print('xx DIFF!')
m = model_i.data
else:
m = 0
ax.imshow(sci_i.data-m, origin='lower',
interpolation='Nearest', vmin=-0.1*vmax, vmax=vmax,
extent=extent, cmap=cmap,
aspect='auto')
ax.set_yticklabels([])
ax.set_xlabel(r'$\lambda$ ($\mu$m) - '+g)
ax.xaxis.set_major_locator(MultipleLocator(GRISM_MAJOR[g]))
if average_only:
iters = []
else:
iters = range(grisms[g])
for ip in iters:
#print(ip, ig)
pa = h0['{0}{1:02d}'.format(g, ip+1)]
sci_i = hdu['SCI', '{0},{1}'.format(g, pa)]
wht_i = hdu['WHT', '{0},{1}'.format(g, pa)]
kern_i = hdu['KERNEL', '{0},{1}'.format(g, pa)]
h_i = sci_i.header
# Kernel
ax = fig.add_subplot(gs[ip, ig*2+0])
sh = kern_i.data.shape
extent = [0, sh[1], 0, sh[0]]
ax.imshow(kern_i.data, origin='lower', interpolation='Nearest',
vmin=-0.1*vmax_kern, vmax=vmax_kern, extent=extent,
cmap=cmap, aspect='auto')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.xaxis.set_tick_params(length=0)
ax.yaxis.set_tick_params(length=0)
# Spectrum
sh = sci_i.data.shape
extent = [h_i['WMIN'], h_i['WMAX'], 0, sh[0]]
ax = fig.add_subplot(gs[ip, ig*2+1])
ax.imshow(sci_i.data, origin='lower',
interpolation='Nearest', vmin=-0.1*vmax, vmax=vmax,
extent=extent, cmap=cmap,
aspect='auto')
ax.set_yticklabels([])
ax.set_xticklabels([])
ax.xaxis.set_major_locator(MultipleLocator(GRISM_MAJOR[g]))
if show_labels:
ax.text(0.015, 0.94, '{0:3.0f}'.format(pa), ha='left',
va='top',
transform=ax.transAxes, fontsize=8,
backgroundcolor='w')
if (ig == (NX-1)) & (ip == 0) & show_labels:
ax.text(0.98, 0.94, 'ID = {0}'.format(h0['ID']),
ha='right', va='top', transform=ax.transAxes,
fontsize=8, backgroundcolor='w')
if average_only:
#pass
gs.tight_layout(fig, pad=0.01)
else:
gs.tight_layout(fig, pad=0.1)
return fig
[docs]def drizzle_2d_spectrum_wcs(beams, data=None, wlimit=[1.05, 1.75], dlam=50,
spatial_scale=1, NY=10, pixfrac=0.6, kernel='square',
convert_to_flambda=True, fcontam=0.2, fill_wht=False,
ds9=None, mask_segmentation=True):
"""Drizzle 2D spectrum from a list of beams
Parameters
----------
beams : list of `~.model.BeamCutout` objects
data : None or list
optionally, drizzle data specified in this list rather than the
contamination-subtracted arrays from each beam.
wlimit : [float, float]
Limits on the wavelength array to drizzle ([wlim, wmax])
dlam : float
Delta wavelength per pixel
spatial_scale : float
Relative scaling of the spatial axis (1 = native pixels)
NY : int
Size of the cutout in the spatial dimension, in output pixels
pixfrac : float
Drizzle PIXFRAC (for `kernel` = 'point')
kernel : str, ('square' or 'point')
Drizzle kernel to use
convert_to_flambda : bool, float
Convert the 2D spectrum to physical units using the sensitivity curves
and if float provided, scale the flux densities by that value
fcontam: float
Factor by which to scale the contamination arrays and add to the
pixel variances.
ds9: `~grizli.ds9.DS9`
Show intermediate steps of the drizzling
Returns
-------
hdu : `~astropy.io.fits.HDUList`
FITS HDUList with the drizzled 2D spectrum and weight arrays
"""
# try:
# import drizzle
# if drizzle.__version__ != '1.12.99':
# # Not the fork that works for all input/output arrays
# raise(ImportError)
#
# #print('drizzle!!')
# from drizzle.dodrizzle import dodrizzle
# drizzler = dodrizzle
# dfillval = '0'
# except:
from drizzlepac import adrizzle
adrizzle.log.setLevel('ERROR')
drizzler = adrizzle.do_driz
dfillval = 0
from stwcs import distortion
from astropy import log
log.setLevel('ERROR')
# log.disable_warnings_logging()
adrizzle.log.setLevel('ERROR')
NX = int(np.round(np.diff(wlimit)[0]*1.e4/dlam)) // 2
center = np.mean(wlimit[:2])*1.e4
out_header, output_wcs = utils.make_spectrum_wcsheader(center_wave=center,
dlam=dlam, NX=NX,
spatial_scale=spatial_scale, NY=NY)
pixscale = 0.128*spatial_scale
# # Get central RA, reference pixel of beam[0]
# #rd = beams[0].get_sky_coords()
# x0 = beams[0].beam.x0.reshape((1,2))
# #x0[0,1] += beam.direct.origin[1]-beam.grism.origin[1]
# rd = beam.grism.wcs.all_pix2world(x0,1)[0]
# theta = 270-beams[0].get_dispersion_PA()
#out_header, output_wcs = utils.make_wcsheader(ra=rd[0], dec=rd[1], size=[50,10], pixscale=pixscale, get_hdu=False, theta=theta)
if True:
theta = -np.arctan2(np.diff(beams[0].beam.ytrace)[0], 1)
undist_wcs = distortion.utils.output_wcs([beams[0].grism.wcs], undistort=True)
undist_wcs = utils.transform_wcs(undist_wcs, rotation=theta, scale=undist_wcs.pscale/pixscale)
output_wcs = undist_wcs.copy()
out_header = utils.to_header(output_wcs)
# Direct image
d_undist_wcs = distortion.utils.output_wcs([beams[0].direct.wcs], undistort=True)
d_undist_wcs = utils.transform_wcs(d_undist_wcs, rotation=0., scale=d_undist_wcs.pscale/pixscale)
d_output_wcs = d_undist_wcs.copy()
# Make square
if hasattr(d_output_wcs, '_naxis1'):
nx1, nx2 = d_output_wcs._naxis1, d_output_wcs._naxis2
else:
nx1, nx2 = d_output_wcs._naxis
d_output_wcs._naxis1, d_output_wcs._naxis2 = nx1, nx2
dx = nx1 - nx2
if hasattr(d_output_wcs, '_naxis1'):
d_output_wcs._naxis1 = d_output_wcs._naxis2
else:
d_output_wcs._naxis[0] = d_output_wcs._naxis[1]
d_output_wcs._naxis1 = d_output_wcs._naxis2 = d_output_wcs._naxis[0]
d_output_wcs.wcs.crpix[0] -= dx/2.
d_out_header = utils.to_header(d_output_wcs)
#delattr(output_wcs, 'orientat')
#beam_header = utils.to_header(beam_wcs)
#output_wcs = beam_wcs
#output_wcs = pywcs.WCS(beam_header, relax=True)
#output_wcs.pscale = utils.get_wcs_pscale(output_wcs)
# shift CRPIX to reference position of beam[0]
sh = (out_header['NAXIS2'], out_header['NAXIS1'])
sh_d = (d_out_header['NAXIS2'], d_out_header['NAXIS1'])
outsci = np.zeros(sh, dtype=np.float32)
outwht = np.zeros(sh, dtype=np.float32)
outctx = np.zeros(sh, dtype=np.int32)
doutsci = np.zeros(sh_d, dtype=np.float32)
doutwht = np.zeros(sh_d, dtype=np.float32)
doutctx = np.zeros(sh_d, dtype=np.int32)
outvar = np.zeros(sh, dtype=np.float32)
outwv = np.zeros(sh, dtype=np.float32)
outcv = np.zeros(sh, dtype=np.int32)
outls = np.zeros(sh, dtype=np.float32)
outlw = np.zeros(sh, dtype=np.float32)
outlc = np.zeros(sh, dtype=np.int32)
if data is None:
data = []
for i, beam in enumerate(beams):
# Contamination-subtracted
beam_data = beam.grism.data['SCI'] - beam.contam
data.append(beam_data)
for i, beam in enumerate(beams):
# Get specific WCS for each beam
beam_header, beam_wcs = beam.get_2d_wcs()
beam_wcs = beam.grism.wcs.deepcopy()
# Shift SIP reference
dx_sip = beam.grism.origin[1] - beam.direct.origin[1]
#beam_wcs.sip.crpix[0] += dx_sip
for wcs_ext in [beam_wcs.sip]:
if wcs_ext is not None:
wcs_ext.crpix[0] += dx_sip
for wcs_ext in [beam_wcs.cpdis1, beam_wcs.cpdis2, beam_wcs.det2im1, beam_wcs.det2im2]:
if wcs_ext is not None:
wcs_ext.crval[0] += dx_sip
# Shift y for trace
xy0 = beam.grism.wcs.all_world2pix(output_wcs.wcs.crval.reshape((1, 2)), 0)[0]
dy = np.interp(xy0[0], np.arange(beam.beam.sh_beam[1]), beam.beam.ytrace)
#beam_wcs.sip.crpix[1] += dy
beam_wcs.wcs.crpix[1] += dy
for wcs_ext in [beam_wcs.sip]:
if wcs_ext is not None:
wcs_ext.crpix[1] += dy
for wcs_ext in [beam_wcs.cpdis1, beam_wcs.cpdis2, beam_wcs.det2im1, beam_wcs.det2im2]:
if wcs_ext is not None:
wcs_ext.crval[1] += dy
if not hasattr(beam_wcs, 'pixel_shape'):
beam_wcs.pixel_shape = beam_wcs._naxis1, beam_wcs._naxis2
if not hasattr(beam_wcs, '_naxis1'):
beam_wcs._naxis1, beam_wcs._naxis2 = beam_wcs._naxis
d_beam_wcs = beam.direct.wcs
if beam.direct['REF'] is None:
d_wht = 1./beam.direct['ERR']**2
d_wht[~np.isfinite(d_wht)] = 0
d_sci = beam.direct['SCI']*1
else:
d_sci = beam.direct['REF']*1
d_wht = d_sci*0.+1
if mask_segmentation:
d_sci *= (beam.beam.seg == beam.id)
# Downweight contamination
# wht = 1/beam.ivar + (fcontam*beam.contam)**2
# wht = np.asarray(1/wht,dtype=np.float32)
# wht[~np.isfinite(wht)] = 0.
contam_weight = np.exp(-(fcontam*np.abs(beam.contam)*np.sqrt(beam.ivar)))
wht = beam.ivar*contam_weight
wht[~np.isfinite(wht)] = 0.
contam_weight[beam.ivar == 0] = 0
data_i = data[i]*1.
scl = 1.
if convert_to_flambda:
#data_i *= convert_to_flambda/beam.beam.sensitivity
#wht *= (beam.beam.sensitivity/convert_to_flambda)**2
scl = convert_to_flambda # /1.e-17
scl *= 1./beam.flat_flam.reshape(beam.beam.sh_beam).sum(axis=0)
#scl = convert_to_flambda/beam.beam.sensitivity
data_i *= scl
wht *= (1/scl)**2
#contam_weight *= scl
wht[~np.isfinite(data_i+scl)] = 0
contam_weight[~np.isfinite(data_i+scl)] = 0
data_i[~np.isfinite(data_i+scl)] = 0
# Go drizzle
data_wave = np.dot(np.ones(beam.beam.sh_beam[0])[:, None], beam.beam.lam[None, :])
drizzler(data_wave, beam_wcs, wht*0.+1, output_wcs,
outls, outlw, outlc, 1., 'cps', 1,
wcslin_pscale=1., uniqid=1,
pixfrac=1, kernel='square', fillval=dfillval)
# Direct image
drizzler(d_sci, d_beam_wcs, d_wht, d_output_wcs,
doutsci, doutwht, doutctx, 1., 'cps', 1,
wcslin_pscale=d_beam_wcs.pscale, uniqid=1,
pixfrac=pixfrac, kernel=kernel, fillval=dfillval)
# Contamination-cleaned
drizzler(data_i, beam_wcs, wht, output_wcs,
outsci, outwht, outctx, 1., 'cps', 1,
wcslin_pscale=beam_wcs.pscale, uniqid=1,
pixfrac=pixfrac, kernel=kernel, fillval=dfillval)
# For variance
drizzler(contam_weight, beam_wcs, wht, output_wcs,
outvar, outwv, outcv, 1., 'cps', 1,
wcslin_pscale=beam_wcs.pscale, uniqid=1,
pixfrac=pixfrac, kernel=kernel, fillval=dfillval)
if ds9 is not None:
ds9.view(outsci, header=out_header)
# if True:
# w, f, e = beam.beam.optimal_extract(data_i, ivar=beam.ivar)
# plt.scatter(w, f, marker='.', color='k', alpha=0.5)
# Correct for drizzle scaling
#outsci /= output_wcs.pscale**2
outls /= output_wcs.pscale**2
wave = np.median(outls, axis=0)
# # Testing
# fl = (sp[1].data*mask).sum(axis=0)
# variance
outvar /= outwv # *output_wcs.pscale**2
outwht = 1/outvar
outwht[(outvar == 0) | (~np.isfinite(outwht))] = 0
# return outwht, outsci, outvar, outwv, output_wcs.pscale
p = pyfits.PrimaryHDU()
p.header['ID'] = (beams[0].id, 'Object ID')
p.header['WMIN'] = (wave[0], 'Minimum wavelength')
p.header['WMAX'] = (wave[-1], 'Maximum wavelength')
p.header['DLAM'] = ((wave[-1]-wave[0])/wave.size, 'Delta wavelength')
p.header['FCONTAM'] = (fcontam, 'Contamination weight')
p.header['PIXFRAC'] = (pixfrac, 'Drizzle PIXFRAC')
p.header['DRIZKRNL'] = (kernel, 'Drizzle kernel')
p.header['NINPUT'] = (len(beams), 'Number of drizzled beams')
for i, beam in enumerate(beams):
p.header['FILE{0:04d}'.format(i+1)] = (beam.grism.parent_file,
'Parent filename')
p.header['GRIS{0:04d}'.format(i+1)] = (beam.grism.filter,
'Beam grism element')
h = out_header.copy()
for k in p.header:
h[k] = p.header[k]
direct_sci = pyfits.ImageHDU(data=doutsci, header=d_out_header, name='DSCI')
grism_sci = pyfits.ImageHDU(data=outsci, header=h, name='SCI')
grism_wht = pyfits.ImageHDU(data=outwht, header=h, name='WHT')
hdul = pyfits.HDUList([p, grism_sci, grism_wht, direct_sci])
return hdul