"""
Roman pipeline step for image alignment.
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
from astropy.table import Table
from roman_datamodels import datamodels as rdm
from stcal.tweakreg import tweakreg
from stcal.tweakreg.tweakreg import _SINGLE_GROUP_REFCAT_STR, SINGLE_GROUP_REFCAT
from romancal.assign_wcs.utils import add_s_region
from romancal.lib.save_wcs import save_wfiwcs
# LOCAL
from ..datamodels import ModelLibrary
from ..stpipe import RomanStep
if TYPE_CHECKING:
from typing import ClassVar
# Default absolute reference catalog: the first entry of stcal's
# SINGLE_GROUP_REFCAT sequence. Used as the default of the ``abs_refcat``
# step parameter and as the fallback when that parameter is empty.
DEFAULT_ABS_REFCAT = SINGLE_GROUP_REFCAT[0]
# Names exported by ``from <this module> import *``.
__all__ = ["TweakRegStep"]
[docs]
class TweakRegStep(RomanStep):
"""
TweakRegStep: Image alignment based on catalogs of sources detected in
input images.

For every input exposure whose type is ``WFI_IMAGE``, `process` builds a
WCS corrector from that image's source catalog. Relative alignment is run
when the input holds more than one image group, and absolute alignment is
run when ``abs_refcat`` is one of stcal's ``SINGLE_GROUP_REFCAT`` entries.
For each successful fit the corrected WCS and a serializable copy of the
fit results are stored back into the corresponding model.
"""
# NOTE(review): presumably the short name stpipe uses to refer to this step
# in configuration and logging -- confirm against RomanStep/stpipe.
class_alias = "tweakreg"
# Step parameter specification. This is an f-string: the default value of
# ``abs_refcat`` and the list of allowed catalogs are interpolated from the
# stcal constants imported at module level.
spec = f"""
use_custom_catalogs = boolean(default=False) # Use custom user-provided catalogs?
catalog_format = string(default='ascii.ecsv') # Catalog output file format
catfile = string(default='') # Name of the file with a list of custom user-provided catalogs
catalog_path = string(default='') # Catalog output file path
enforce_user_order = boolean(default=False) # Align images in user specified order?
expand_refcat = boolean(default=False) # Expand reference catalog with new sources?
minobj = integer(default=15) # Minimum number of objects acceptable for matching
searchrad = float(default=2.0) # The search radius in arcsec for a match
use2dhist = boolean(default=True) # Use 2d histogram to find initial offset?
separation = float(default=1.0) # Minimum object separation in arcsec
tolerance = float(default=0.7) # Matching tolerance for xyxymatch in arcsec
fitgeometry = option('shift', 'rshift', 'rscale', 'general', default='rshift') # Fitting geometry
nclip = integer(min=0, default=3) # Number of clipping iterations in fit
sigma = float(min=0.0, default=3.0) # Clipping limit in sigma units
abs_refcat = string(default='{DEFAULT_ABS_REFCAT}') # Absolute reference
# catalog. Options: {_SINGLE_GROUP_REFCAT_STR}
save_abs_catalog = boolean(default=False) # Write out used absolute astrometric reference catalog as a separate product
abs_minobj = integer(default=15) # Minimum number of objects acceptable for matching when performing absolute astrometry
abs_searchrad = float(default=6.0) # The search radius in arcsec for a match when performing absolute astrometry
# We encourage setting this parameter to True. Otherwise, xoffset and yoffset will be set to zero.
abs_use2dhist = boolean(default=True) # Use 2D histogram to find initial offset when performing absolute astrometry?
abs_separation = float(default=1.0) # Minimum object separation in arcsec when performing absolute astrometry
abs_tolerance = float(default=0.7) # Matching tolerance for xyxymatch in arcsec when performing absolute astrometry
# Fitting geometry when performing absolute astrometry
abs_fitgeometry = option('shift', 'rshift', 'rscale', 'general', default='rshift')
abs_nclip = integer(min=0, default=3) # Number of clipping iterations in fit when performing absolute astrometry
abs_sigma = float(min=0.0, default=3.0) # Clipping limit in sigma units when performing absolute astrometry
output_use_model = boolean(default=True) # When saving use `DataModel.meta.filename`
update_source_catalog_coordinates = boolean(default=False) # Update source catalog file with tweaked coordinates?
save_l1_wcs = boolean(default=True)
"""
# Empty list: this step declares no reference file types.
reference_file_types: ClassVar = []
[docs]
def process(self, input):
    """
    Align the input images using their source catalogs.

    The input is wrapped in a `ModelLibrary` (unless it already is one).
    For every ``WFI_IMAGE`` exposure a WCS corrector is built from the
    image's tweakreg catalog; other exposure types are marked ``SKIPPED``.
    Relative alignment runs when there is more than one image group, and
    absolute alignment runs when ``abs_refcat`` is one of
    ``SINGLE_GROUP_REFCAT``. For each successful fit the corrected WCS and
    a serializable copy of the fit results are stored in the model.

    Parameters
    ----------
    input : DataModel, str, ModelLibrary, or other ModelLibrary input
        A single data model, a name ending in ``.asdf``, an already open
        `ModelLibrary`, or anything `ModelLibrary` accepts.

    Returns
    -------
    ModelLibrary
        The library holding the (possibly updated) models.

    Raises
    ------
    TypeError
        If the input cannot be turned into a `ModelLibrary`.
    ValueError
        If the library is empty or a catalog lacks the required columns.
    AttributeError
        If an image has no usable source catalog information.
    """
    # properly handle input
    try:
        if isinstance(input, rdm.DataModel):
            images = ModelLibrary([input])
        elif str(input).endswith(".asdf"):
            images = ModelLibrary([rdm.open(input)])
        elif isinstance(input, ModelLibrary):
            images = input
        else:
            images = ModelLibrary(input)
    except TypeError as e:
        # replace the first message with a user-oriented one, keep the rest
        e.args = (
            "Input to tweakreg must be a list of DataModels, an "
            "association, or an already open ModelLibrary "
            "containing one or more DataModels.",
        ) + e.args[1:]
        raise e

    if not images:
        raise ValueError("Input must contain at least one image model.")

    self.log.info(
        f"Number of image groups to be aligned: {len(images.group_indices):d}."
    )
    self.log.info("Image groups:")
    for name in images.group_names:
        self.log.info(f" {name}")

    # set the first image as reference
    with images:
        ref_image = images.borrow(0)
        images.shelve(ref_image, 0, modify=False)

    catdict = _parse_catfile(self.catfile)
    use_custom_catalogs = self.use_custom_catalogs

    # if user requested the use of custom catalogs and provided a
    # valid 'catfile' file name that has no custom catalogs,
    # turn off the use of custom catalogs:
    if catdict is not None and not catdict:
        self.log.warning(
            "'use_custom_catalogs' is set to True but 'catfile' "
            "contains no user catalogs."
        )
        use_custom_catalogs = False

    if use_custom_catalogs and catdict:
        with images:
            for i, member in enumerate(images.asn["products"][0]["members"]):
                filename = member["expname"]
                if filename not in catdict:
                    # Nothing was borrowed for this member, so there is
                    # nothing to shelve (shelving here would reference an
                    # unbound or stale ``model``).
                    continue
                # FIXME: I'm not sure if this captures all the possible combinations
                # for example, meta.tweakreg_catalog is set by the container (when
                # it's present in the association). However the code in this step
                # checks meta.source_catalog.tweakreg_catalog. I think this means
                # that setting a catalog via an association does not work. Is this
                # intended? If so, the container can be updated to not support that.
                model = images.borrow(i)
                model.meta["source_catalog"] = {
                    "tweakreg_catalog_name": catdict[filename],
                }
                images.shelve(model, i)

    # set path where the source catalog will be saved to
    if len(self.catalog_path) == 0:
        self.catalog_path = os.getcwd()
    self.catalog_path = Path(self.catalog_path).as_posix()
    self.log.info(f"All source catalogs will be saved to: {self.catalog_path}")

    # set reference catalog name
    if not self.abs_refcat:
        self.abs_refcat = DEFAULT_ABS_REFCAT.strip().upper()
    if self.abs_refcat != DEFAULT_ABS_REFCAT:
        self.expand_refcat = True

    # build the catalogs for input images
    imcats = []
    with images:
        for i, image_model in enumerate(images):
            exposure_type = image_model.meta.exposure.type
            if exposure_type != "WFI_IMAGE":
                self.log.info("Skipping TweakReg for spectral exposure.")
                image_model.meta.cal_step.tweakreg = "SKIPPED"
            else:
                source_catalog = getattr(image_model.meta, "source_catalog", None)
                if source_catalog is None:
                    # return the borrowed model before bailing out
                    images.shelve(image_model, i, modify=False)
                    raise AttributeError(
                        "Attribute 'meta.source_catalog' is missing. "
                        "Please either run SourceCatalogStep or provide a custom source catalog."
                    )

                try:
                    catalog = self.get_tweakreg_catalog(source_catalog, image_model)
                except AttributeError as e:
                    self.log.error(f"Failed to retrieve tweakreg_catalog: {e}")
                    images.shelve(image_model, i, modify=False)
                    raise e

                try:
                    # validate catalog columns
                    _validate_catalog_columns(catalog)
                except ValueError as e:
                    self.log.error(f"Failed to validate catalog columns: {e}")
                    images.shelve(image_model, i, modify=False)
                    raise e

                catalog = tweakreg.filter_catalog_by_bounding_box(
                    catalog, image_model.meta.wcs.bounding_box
                )

                if self.save_abs_catalog:
                    output_name = os.path.join(
                        self.catalog_path, f"fit_{self.abs_refcat.lower()}_ref.ecsv"
                    )
                    catalog.write(
                        output_name, format=self.catalog_format, overwrite=True
                    )

                image_model.meta["tweakreg_catalog"] = catalog.as_array()
                nsources = len(catalog)
                self.log.info(
                    f"Detected {nsources} sources in {image_model.meta.filename}."
                    if nsources
                    else f"No sources found in {image_model.meta.filename}."
                )

                # build image catalog
                # catalog name
                catalog_name = os.path.splitext(image_model.meta.filename)[0].strip(
                    "_- "
                )
                # catalog data
                catalog_table = Table(image_model.meta.tweakreg_catalog)
                catalog_table.meta["name"] = catalog_name

                imcat = tweakreg.construct_wcs_corrector(
                    wcs=image_model.meta.wcs,
                    refang=image_model.meta.wcsinfo,
                    catalog=catalog_table,
                    group_id=images._model_to_group_id(image_model),
                )
                # remember which library slot this corrector belongs to
                imcat.meta["model_index"] = i
                imcats.append(imcat)
            images.shelve(image_model, i)

    # run alignment only if it was possible to build image catalogs
    if imcats:
        # extract WCS correctors to use for image alignment
        if len(images.group_indices) > 1:
            self.do_relative_alignment(imcats)
        if self.abs_refcat in SINGLE_GROUP_REFCAT:
            self.do_absolute_alignment(ref_image, imcats)

    # keys of the fit results that are not stored in the datamodel
    unwanted_fit_keys = {
        "eff_minobj",
        "matched_ref_idx",
        "matched_input_idx",
        "fit_RA",
        "fit_DEC",
        "fitmask",
    }

    # finalize step
    with images:
        for imcat in imcats:
            image_model = images.borrow(imcat.meta["model_index"])
            image_model.meta.cal_step["tweakreg"] = "COMPLETE"

            # remove source catalog
            del image_model.meta["tweakreg_catalog"]

            # 'fit_info' is absent when neither alignment ran (e.g. a single
            # image group with an 'abs_refcat' outside SINGLE_GROUP_REFCAT);
            # treat that as "no successful fit" instead of failing.
            fit_info = imcat.meta.get("fit_info") or {}

            # retrieve fit status and update wcs if fit is successful:
            if "SUCCESS" in fit_info.get("status", ""):
                # Update/create the WCS .name attribute with information
                # on this astrometric fit as the only record that it was
                # successful:

                # NOTE: This .name attrib agreed upon by the JWST Cal
                #       Working Group.
                #       Current value is merely a place-holder based
                #       on HST conventions. This value should also be
                #       translated to the FITS WCSNAME keyword
                #       IF that is what gets recorded in the archive
                #       for end-user searches.
                imcat.wcs.name = f"FIT-LVL2-{self.abs_refcat}"

                # serialize object from tweakwcs
                # (typecasting numpy objects to python types so that it doesn't cause an
                # issue when saving datamodel to ASDF); unwanted keys are
                # filtered out up front so a missing key cannot raise.
                image_model.meta["wcs_fit_results"] = {
                    k: (
                        v.tolist()
                        if isinstance(v, np.ndarray | np.bool_)
                        else v
                    )
                    for k, v in fit_info.items()
                    if k not in unwanted_fit_keys
                }

                # update WCS
                image_model.meta.wcs = imcat.wcs
                # update S_REGION
                add_s_region(image_model)

            images.shelve(image_model, imcat.meta["model_index"])

    # Write out the WfiWcs products
    if self.save_l1_wcs:
        save_wfiwcs(self, images, force=True)

    return images
[docs]
def update_catalog_coordinates(self, tweakreg_catalog_name, tweaked_wcs):
    """
    Recompute the sky coordinates stored in a source catalog file.

    The catalog file is opened, its sky-coordinate columns are replaced by
    the values obtained from evaluating ``tweaked_wcs`` on the matching
    pixel-coordinate columns, and the model is saved again.

    Parameters
    ----------
    tweakreg_catalog_name : str
        The name of the TweakReg catalog file produced by `SourceCatalog`.
    tweaked_wcs : `gwcs.wcs.WCS`
        The tweaked World Coordinate System (WCS) object; it is called
        with the x and y columns and must return two outputs.

    Returns
    -------
    None
    """
    # (pixel x, pixel y) columns and the (ra, dec) columns derived from them
    pixel_to_sky_columns = (
        (("xcentroid", "ycentroid"), ("ra_centroid", "dec_centroid")),
        (("x_psf", "y_psf"), ("ra_psf", "dec_psf")),
    )

    with rdm.open(tweakreg_catalog_name) as source_catalog_model:
        catalog = source_catalog_model.source_catalog

        for (x_col, y_col), (ra_col, dec_col) in pixel_to_sky_columns:
            # evaluate the tweaked WCS, then overwrite the sky columns
            new_ra, new_dec = tweaked_wcs(catalog[x_col], catalog[y_col])
            catalog[ra_col], catalog[dec_col] = new_ra, new_dec

        # save updated catalog (overwrite cat file)
        self.save_model(
            source_catalog_model,
            output_file=source_catalog_model.meta.filename,
            suffix="cat",
            force=True,
        )
[docs]
def read_catalog(self, catalog_name):
    """
    Read a source catalog from a file.

    The reader is selected from the end of the file name:

    * names ending in "asdf" are opened with roman datamodels and the
      model's ``source_catalog`` attribute is returned;
    * names ending in "parquet" are read with astropy ``Table.read``
      using ``format="parquet"``;
    * anything else is read with astropy ``Table.read`` using the step's
      ``catalog_format`` parameter.

    Parameters
    ----------
    catalog_name : str
        The name of the catalog file to read.

    Returns
    -------
    Table
        The catalog that was read.

    Notes
    -----
    No format validation is done here; errors raised by the underlying
    readers propagate to the caller.
    """
    is_parquet = catalog_name.endswith("parquet")
    table_format = "parquet" if is_parquet else self.catalog_format

    if not catalog_name.endswith("asdf"):
        return Table.read(catalog_name, format=table_format)

    # leave this for now
    with rdm.open(catalog_name) as source_catalog_model:
        return source_catalog_model.source_catalog
[docs]
def get_tweakreg_catalog(self, source_catalog, image_model):
    """
    Retrieve the tweakreg catalog from the source catalog metadata.

    Two sources are tried in order:

    1. ``source_catalog.tweakreg_catalog`` -- in-memory catalog data. When
       present and non-empty it is converted to a `Table` and the
       ``tweakreg_catalog`` entry is removed from
       ``image_model.meta.source_catalog``.
    2. ``source_catalog.tweakreg_catalog_name`` -- name of a catalog file,
       which is read with `read_catalog`.

    Parameters
    ----------
    source_catalog : object
        The source catalog metadata containing catalog information.
    image_model : DataModel
        The image model associated with the source catalog metadata.

    Returns
    -------
    Table
        The retrieved tweakreg catalog.

    Raises
    ------
    AttributeError
        If neither a non-empty in-memory catalog nor a catalog file name
        is available.
    """
    in_memory_catalog = getattr(source_catalog, "tweakreg_catalog", None)
    # Test for None and length explicitly: truth-testing a numpy array with
    # more than one row raises ValueError ("truth value ... is ambiguous").
    # An empty catalog still falls through to the file name, as before.
    if in_memory_catalog is not None and len(in_memory_catalog) > 0:
        tweakreg_catalog = Table(np.asarray(in_memory_catalog))
        # the data now lives in the returned Table; drop it from the model
        del image_model.meta.source_catalog["tweakreg_catalog"]
        return tweakreg_catalog

    catalog_name = getattr(source_catalog, "tweakreg_catalog_name", None)
    if catalog_name:
        return self.read_catalog(catalog_name)

    raise AttributeError(
        "Attribute 'meta.source_catalog.tweakreg_catalog' is missing. "
        "Please either run SourceCatalogStep or provide a custom source catalog."
    )
[docs]
def do_relative_alignment(self, imcats):
    """
    Align the image catalogs relative to each other.

    Forwards the step's relative-alignment parameters to
    ``tweakreg.relative_align``. The initial x/y offsets are fixed at zero
    and ``clip_accum`` is always enabled.

    Parameters
    ----------
    imcats : list
        A list of image catalogs containing source information for alignment.

    Returns
    -------
    None
    """
    # keyword options in the same order the alignment routine is called with
    align_options = {
        "searchrad": self.searchrad,
        "separation": self.separation,
        "use2dhist": self.use2dhist,
        "tolerance": self.tolerance,
        "xoffset": 0,
        "yoffset": 0,
        "enforce_user_order": self.enforce_user_order,
        "expand_refcat": self.expand_refcat,
        "minobj": self.minobj,
        "fitgeometry": self.fitgeometry,
        "nclip": self.nclip,
        "sigma": self.sigma,
        "clip_accum": True,
    }
    tweakreg.relative_align(imcats, **align_options)
[docs]
def do_absolute_alignment(self, ref_image, imcats):
    """
    Align the image catalogs to the absolute reference catalog.

    Forwards the step's ``abs_*`` parameters to ``tweakreg.absolute_align``
    together with the WCS, WCS info and exposure start epoch taken from
    ``ref_image``. ``clip_accum`` is always enabled and the absolute
    catalog output directory is the step's ``output_dir``.

    Parameters
    ----------
    ref_image : DataModel
        The reference image used for alignment, which contains WCS information.
    imcats : list
        A list of image catalogs containing source information for alignment.

    Returns
    -------
    None
    """
    ref_meta = ref_image.meta
    align_options = {
        "ref_wcs": ref_meta.wcs,
        "ref_wcsinfo": ref_meta.wcsinfo,
        # epoch of the reference exposure expressed as a decimal year
        "epoch": ref_meta.exposure.start_time.decimalyear,
        "abs_minobj": self.abs_minobj,
        "abs_fitgeometry": self.abs_fitgeometry,
        "abs_nclip": self.abs_nclip,
        "abs_sigma": self.abs_sigma,
        "abs_searchrad": self.abs_searchrad,
        "abs_use2dhist": self.abs_use2dhist,
        "abs_separation": self.abs_separation,
        "abs_tolerance": self.abs_tolerance,
        "save_abs_catalog": self.save_abs_catalog,
        "abs_catalog_output_dir": self.output_dir,
        "clip_accum": True,
    }
    tweakreg.absolute_align(imcats, self.abs_refcat, **align_options)
def _parse_catfile(catfile):
    """
    Parse a catalog list file into a mapping of data model name to catalog path.

    Each non-blank line that does not start with ``#`` holds a data model
    name optionally followed by one catalog file name, separated by
    whitespace. Catalog names are joined onto the directory of ``catfile``.

    Parameters
    ----------
    catfile : str or None
        The path to the catalog list file to be parsed.

    Returns
    -------
    dict or None
        ``None`` when ``catfile`` is ``None`` or blank. Otherwise a dict
        mapping each data model name to its catalog path, or to ``None``
        when the line names no catalog. A file without entries yields an
        empty dict.

    Raises
    ------
    ValueError
        If a line contains more than two columns.
    """
    if catfile is None or not catfile.strip():
        return None

    base_dir = os.path.dirname(catfile)
    mapping = {}
    with open(catfile) as stream:
        for raw_line in stream:
            # whitespace split also discards leading/trailing blanks
            fields = raw_line.split()
            if not fields or fields[0].startswith("#"):
                continue

            model_name, extras = fields[0], fields[1:]
            if len(extras) > 1:
                raise ValueError("'catfile' can contain at most two columns.")

            mapping[model_name] = (
                os.path.join(base_dir, extras[0]) if extras else None
            )
    return mapping
def _validate_catalog_columns(catalog):
    """
    Ensure the catalog has columns named 'x' and 'y'.

    For each of 'x' and 'y': if the column is absent, the corresponding
    '<axis>_psf' column is renamed to it in place. If that column is
    absent as well, an error is raised.

    Parameters
    ----------
    catalog : Table
        The catalog to validate; it is modified in place when a column
        has to be renamed.

    Returns
    -------
    Table
        The same catalog object that was passed in.

    Raises
    ------
    ValueError
        If neither '<axis>' nor '<axis>_psf' exists for one of the axes.
    """
    for short_name in ("x", "y"):
        if short_name in catalog.colnames:
            continue

        psf_name = f"{short_name}_psf"
        if psf_name not in catalog.colnames:
            raise ValueError(
                "'tweakreg' source catalogs must contain a header with "
                "columns named either 'x' and 'y' or 'x_psf' and 'y_psf'."
            )
        catalog.rename_column(psf_name, short_name)
    return catalog