Source code for optika.distortion._distortion

import abc
import dataclasses
import functools
import numpy as np
import matplotlib.axes
import matplotlib.cm
import matplotlib.colors
import matplotlib.figure
import matplotlib.pyplot as plt
import astropy.units as u
import astropy.visualization
import named_arrays as na
import optika

# Public names of this module, i.e. what ``from ... import *`` exposes.
# The docstring examples in this module reach these classes as
# ``optika.distortion.<name>``.
__all__ = [
    "AbstractDistortionModel",
    "AbstractLinearDistortionModel",
    "SimpleDistortionModel",
    "AbstractInterpolatedDistortionModel",
    "PolynomialDistortionModel",
]


@dataclasses.dataclass(eq=False, repr=False)
class AbstractDistortionModel(
    optika.mixins.Printable,
):
    """
    Interface for a distortion model.

    A distortion model maps scene coordinates to sensor coordinates via
    :meth:`distort`, and maps sensor coordinates back to the scene via
    :meth:`undistort`.

    Every coordinate carries a wavelength alongside its position, because
    where a point in the scene lands on the sensor generally depends on
    wavelength (for example, the dispersion of a spectrograph).

    Implementations are not required to make :meth:`distort` and
    :meth:`undistort` exact inverses of one another. The round trip is only
    as accurate as the model itself.
    """

    @abc.abstractmethod
    def distort(
        self,
        coordinates: na.AbstractSpectralPositionalVectorArray,
    ) -> na.SpectralPositionalVectorArray:
        """
        Map scene coordinates onto the sensor.

        Parameters
        ----------
        coordinates
            The wavelength and scene position of every point to map.
        """

    @abc.abstractmethod
    def undistort(
        self,
        coordinates: na.AbstractSpectralPositionalVectorArray,
    ) -> na.SpectralPositionalVectorArray:
        """
        Map sensor coordinates back into the scene.

        Parameters
        ----------
        coordinates
            The wavelength and sensor position of every point to map.
        """
@dataclasses.dataclass(eq=False, repr=False)
class AbstractLinearDistortionModel(
    AbstractDistortionModel,
):
    r"""
    Base class for distortion models that are an affine map of the scene
    coordinates,

    .. math::

        \text{distort}(\vec{c}) = \mathbf{M} \, (\vec{c} - \vec{c}_0) + \vec{b},

    where :math:`\mathbf{M}` is :attr:`matrix`, :math:`\vec{c}_0` is
    :attr:`center`, and :math:`\vec{b}` is :attr:`intercept`.

    Because the map is affine, :meth:`undistort` inverts it *exactly*,
    provided :attr:`matrix` is invertible. This is unlike a fitted model such
    as a polynomial.
    """

    @property
    @abc.abstractmethod
    def matrix(self) -> na.AbstractSpectralPositionalMatrixArray:
        """Linear part of the affine map."""

    @property
    @abc.abstractmethod
    def center(self) -> na.AbstractSpectralPositionalVectorArray:
        """Reference point subtracted from the input before applying :attr:`matrix`."""

    @property
    @abc.abstractmethod
    def intercept(self) -> na.AbstractSpectralPositionalVectorArray:
        """Constant offset added after applying :attr:`matrix`."""

    def distort(
        self,
        coordinates: na.AbstractSpectralPositionalVectorArray,
    ) -> na.SpectralPositionalVectorArray:
        # Move to the model's reference point, apply the linear part, then
        # translate onto the sensor.
        relative = coordinates - self.center
        return self.matrix @ relative + self.intercept

    def undistort(
        self,
        coordinates: na.AbstractSpectralPositionalVectorArray,
    ) -> na.SpectralPositionalVectorArray:
        # Exact inverse of `distort`: remove the translation, apply M^-1,
        # then restore the reference point.
        relative = coordinates - self.intercept
        return self.matrix.inverse @ relative + self.center
@dataclasses.dataclass(eq=False, repr=False)
class SimpleDistortionModel(
    AbstractLinearDistortionModel,
):
    r"""
    An analytic distortion model with three ingredients:

    - a rotation of the field;
    - an isotropic plate scale;
    - a linear spectral dispersion along the sensor :math:`x` axis.

    This describes an idealized spectrograph. At the :attr:`reference`
    wavelength, the center of the field lands on the :attr:`reference`
    position of the sensor. Other wavelengths are shifted along the
    dispersion direction.

    Examples
    --------

    Distort a grid of scene coordinates and plot the result on the sensor,
    colored by wavelength.

    .. jupyter-execute::

        import matplotlib.pyplot as plt
        import astropy.units as u
        import named_arrays as na
        import optika

        model = optika.distortion.SimpleDistortionModel(
            plate_scale=1 * u.arcsec / u.pix,
            dispersion=2 * u.nm / u.pix,
            angle=15 * u.deg,
            reference=na.SpectralPositionalVectorArray(
                wavelength=550 * u.nm,
                position=na.Cartesian2dVectorArray(0, 0) * u.pix,
            ),
        )

        scene = na.SpectralPositionalVectorArray(
            wavelength=na.linspace(500, 600, axis="wavelength", num=3) * u.nm,
            position=na.Cartesian2dVectorLinearSpace(
                start=-10 * u.arcsec,
                stop=+10 * u.arcsec,
                axis=na.Cartesian2dVectorArray("field_x", "field_y"),
                num=5,
            ),
        )

        sensor = model.distort(scene)

        fig, ax = plt.subplots(constrained_layout=True)
        ax.set_aspect("equal")
        for wavelength in scene.wavelength.ndarray:
            na.plt.scatter(
                sensor.position.x,
                sensor.position.y,
                where=scene.wavelength == wavelength,
                label=f"{wavelength}",
                ax=ax,
            )
        ax.set_xlabel(f"detector $x$ ({na.unit(sensor.position.x):latex_inline})")
        ax.set_ylabel(f"detector $y$ ({na.unit(sensor.position.y):latex_inline})")
        ax.legend();
    """

    plate_scale: u.Quantity | na.AbstractScalar = dataclasses.MISSING
    """The spatial plate scale, in units such as :math:`\\text{arcsec} / \\text{pix}`."""

    dispersion: u.Quantity | na.AbstractScalar = dataclasses.MISSING
    """The magnitude of the spectral dispersion, in units such as :math:`\\text{nm} / \\text{pix}`."""

    angle: u.Quantity | na.AbstractScalar = dataclasses.MISSING
    """The angle of the dispersion direction with respect to the scene."""

    reference: na.AbstractSpectralPositionalVectorArray = dataclasses.MISSING
    """The reference wavelength, and the sensor position the field center maps to at that wavelength."""

    @functools.cached_property
    def matrix(self) -> na.SpectralPositionalMatrixArray:
        cos_angle = np.cos(self.angle)
        sin_angle = np.sin(self.angle)
        scale = self.plate_scale
        dispersion = self.dispersion
        unit_wavelength = na.unit(self.reference.wavelength)

        # Output wavelength equals input wavelength. The position terms are
        # zero but carry units so the row stays dimensionally consistent.
        row_wavelength = na.SpectralPositionalVectorArray(
            wavelength=1,
            position=na.Cartesian2dVectorArray(
                x=0 * unit_wavelength / u.arcsec,
                y=0 * unit_wavelength / u.arcsec,
            ),
        )

        # Sensor x: the rotated scene x plus the linear dispersion term.
        row_x = na.SpectralPositionalVectorArray(
            wavelength=1 / dispersion,
            position=na.Cartesian2dVectorArray(
                x=cos_angle / scale,
                y=-sin_angle / scale,
            ),
        )

        # Sensor y: the rotated scene y only. The wavelength coefficient is
        # zero, expressed in units compatible with 1 / dispersion.
        row_y = na.SpectralPositionalVectorArray(
            wavelength=0 / dispersion,
            position=na.Cartesian2dVectorArray(
                x=sin_angle / scale,
                y=cos_angle / scale,
            ),
        )

        return na.SpectralPositionalMatrixArray(
            wavelength=row_wavelength,
            position=na.Cartesian2dMatrixArray(x=row_x, y=row_y),
        )

    @property
    def center(self) -> na.SpectralPositionalVectorArray:
        # The field center at the reference wavelength.
        origin = na.Cartesian2dVectorArray(0, 0) * u.arcsec
        return na.SpectralPositionalVectorArray(
            wavelength=self.reference.wavelength,
            position=origin,
        )

    @property
    def intercept(self) -> na.AbstractSpectralPositionalVectorArray:
        # The center maps straight onto the reference sensor coordinates.
        return self.reference
@dataclasses.dataclass(eq=False, repr=False)
class AbstractInterpolatedDistortionModel(
    AbstractDistortionModel,
):
    """
    Base class for distortion models built by interpolating between known
    pairs of scene and sensor coordinates.

    The calibration pairs are given by :attr:`coordinates_scene` and
    :attr:`coordinates_sensor`. Subclasses decide how to interpolate between
    them.
    """

    @property
    @abc.abstractmethod
    def coordinates_scene(self) -> na.AbstractSpectralPositionalVectorArray:
        """
        Wavelength and scene position of every calibration point.
        """

    @property
    @abc.abstractmethod
    def coordinates_sensor(self) -> na.AbstractCartesian2dVectorArray:
        """
        Sensor position that every calibration point maps to.
        """

    @property
    @abc.abstractmethod
    def axis_wavelength(self) -> str:
        """Name of the logical axis along which wavelength varies."""

    @property
    @abc.abstractmethod
    def axis_field(self) -> tuple[str, str]:
        """Names of the logical axes along which scene position varies."""
@dataclasses.dataclass(eq=False, repr=False)
class PolynomialDistortionModel(
    AbstractInterpolatedDistortionModel,
):
    """
    A distortion model which fits a polynomial to known scene/sensor coordinates.

    The forward model (:meth:`distort`) is a polynomial fit that maps scene
    position to sensor position as a function of wavelength.
    The inverse model (:meth:`undistort`) is a *separate* polynomial fit in
    the opposite direction. The round trip is therefore exact only to the
    accuracy of the two fits.

    Examples
    --------

    Plot the fit residual of a distortion model with a deliberately underfit
    (linear) polynomial.

    .. jupyter-execute::

        import numpy as np
        import astropy.units as u
        import named_arrays as na
        import optika

        scene = na.SpectralPositionalVectorArray(
            wavelength=na.linspace(500, 600, axis="wavelength", num=3) * u.nm,
            position=na.Cartesian2dVectorLinearSpace(
                start=-1 * u.deg,
                stop=+1 * u.deg,
                axis=na.Cartesian2dVectorArray("field_x", "field_y"),
                num=13,
            ),
        )

        sensor = na.Cartesian2dVectorArray(
            x=scene.position.x * (10 * u.mm / u.deg) + scene.position.x**2 * (1 * u.mm / u.deg**2),
            y=scene.position.y * (10 * u.mm / u.deg) + scene.position.y**2 * (1 * u.mm / u.deg**2),
        )

        model = optika.distortion.PolynomialDistortionModel(
            coordinates_scene=scene,
            coordinates_sensor=sensor,
            axis_wavelength="wavelength",
            axis_field=("field_x", "field_y"),
            degree=1,
        )

        fig, ax = model.plot_residual()
        na.plt.set_aspect("equal", ax=ax);
    """

    coordinates_scene: na.AbstractSpectralPositionalVectorArray = dataclasses.MISSING
    """The wavelength and position of each calibration point in the scene."""

    coordinates_sensor: na.AbstractCartesian2dVectorArray = dataclasses.MISSING
    """The position of each calibration point mapped onto the sensor."""

    axis_wavelength: str = dataclasses.MISSING
    """The logical axis corresponding to changing wavelength."""

    axis_field: tuple[str, str] = dataclasses.MISSING
    """The logical axes corresponding to changing position in the scene."""

    degree: int = 1
    """The degree of the polynomial used to model the distortion."""

    where: bool | na.AbstractScalar = True
    """A boolean mask selecting which calibration points to use for fitting."""

    @property
    def _axis_scene(self) -> tuple[str, ...]:
        """The logical axes over which the calibration points are distributed."""
        return (self.axis_wavelength, *self.axis_field)

    @functools.cached_property
    def fit(self) -> na.PolynomialFitFunctionArray:
        """The polynomial fit mapping scene position to sensor position."""
        scene = self.coordinates_scene
        return na.PolynomialFitFunctionArray(
            inputs=scene,
            outputs=self.coordinates_sensor,
            center=scene.mean(self._axis_scene),
            degree=self.degree,
            where_polynomial=self.where,
        )

    @functools.cached_property
    def fit_inverse(self) -> na.PolynomialFitFunctionArray:
        """The polynomial fit mapping sensor position back to scene position."""
        scene = self.coordinates_scene
        # The inverse fit takes (wavelength, sensor position) as its inputs.
        inputs = na.SpectralPositionalVectorArray(
            wavelength=scene.wavelength,
            position=self.coordinates_sensor,
        )
        return na.PolynomialFitFunctionArray(
            inputs=inputs,
            outputs=scene.position,
            center=inputs.mean(self._axis_scene),
            degree=self.degree,
            where_polynomial=self.where,
        )

    def distort(
        self,
        coordinates: na.AbstractSpectralPositionalVectorArray,
    ) -> na.SpectralPositionalVectorArray:
        # Wavelength passes through unchanged; only the position is mapped.
        return na.SpectralPositionalVectorArray(
            wavelength=coordinates.wavelength,
            position=self.fit(coordinates).outputs,
        )

    def undistort(
        self,
        coordinates: na.AbstractSpectralPositionalVectorArray,
    ) -> na.SpectralPositionalVectorArray:
        # Wavelength passes through unchanged; only the position is mapped.
        return na.SpectralPositionalVectorArray(
            wavelength=coordinates.wavelength,
            position=self.fit_inverse(coordinates).outputs,
        )

    @staticmethod
    def _field_aspect_ratio(
        position: na.AbstractCartesian2dVectorArray,
    ) -> float:
        """
        Return the width-to-height ratio of the field spanned by `position`.

        The ratio is converted explicitly to a dimensionless number, so x and
        y components in different (but compatible) units, e.g. arcsec vs deg,
        are scaled correctly. A degenerate field (zero extent along either
        axis) falls back to a square aspect of 1.
        """
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio = position.x.ptp() / position.y.ptp()
        # `u.Quantity` also accepts plain floats, which become dimensionless.
        aspect = u.Quantity(na.as_named_array(ratio).ndarray).to_value(
            u.dimensionless_unscaled
        )
        aspect = float(aspect)
        if not np.isfinite(aspect) or aspect <= 0:
            return 1.0
        return aspect

    def plot_residual(
        self,
        figsize: None | tuple[float, float] = None,
        cmap: None | str | matplotlib.colors.Colormap = None,
        vmin: None | na.ArrayLike = None,
        vmax: None | na.ArrayLike = None,
        **kwargs,
    ) -> tuple[matplotlib.figure.Figure, na.ScalarArray]:
        """
        Plot the residual of the forward :attr:`fit` as a function of field
        angle, with a separate subplot for each wavelength.

        The residual is the magnitude of the difference between the
        calibration sensor positions, :attr:`coordinates_sensor`, and the
        positions predicted by the forward polynomial fit.

        Parameters
        ----------
        figsize
            The size of the returned figure in inches.
            If :obj:`None`, the size is chosen automatically from the number
            of wavelengths and the aspect ratio of the field of view.
        cmap
            The colormap used to map the residual magnitude to colors.
        vmin
            The residual value mapped to the lowest color.
            If :obj:`None`, defaults to zero.
        vmax
            The residual value mapped to the highest color.
            If :obj:`None`, defaults to the maximum residual.
        kwargs
            Additional keyword arguments passed to
            :func:`named_arrays.plt.pcolormesh`.
        """
        scene = self.coordinates_scene
        position = scene.position
        wavelength = na.as_named_array(scene.wavelength)
        axis_wavelength = self.axis_wavelength

        residual = (self.coordinates_sensor - self.fit.predictions).length
        unit = na.unit(residual)

        if vmin is None:
            vmin = 0 * unit
        if vmax is None:
            vmax = residual.max()

        ncols = na.shape(wavelength).get(axis_wavelength, 1)

        if figsize is None:
            # Shape each subplot to the field-of-view aspect ratio, and widen
            # the figure to fit one subplot per wavelength.
            height_subplot = 3
            aspect = self._field_aspect_ratio(position)
            figsize = (
                ncols * height_subplot * aspect + 1.5,
                height_subplot + 1,
            )

        with astropy.visualization.quantity_support():
            fig, ax = na.plt.subplots(
                axis_cols=axis_wavelength,
                ncols=ncols,
                sharex=True,
                sharey=True,
                squeeze=False,
                figsize=figsize,
                constrained_layout=True,
            )

            # One shared colorizer, so every subplot uses the same color scale.
            colorizer = plt.Colorizer(
                cmap=cmap,
                norm=plt.Normalize(
                    vmin=na.as_named_array(vmin).ndarray.to_value(unit),
                    vmax=na.as_named_array(vmax).ndarray.to_value(unit),
                ),
            )

            na.plt.pcolormesh(
                position,
                C=residual,
                ax=ax,
                colorizer=colorizer,
                **kwargs,
            )

            na.plt.set_xlabel(f"field $x$ ({na.unit(position.x):latex_inline})", ax=ax)
            # Only the first column gets a y label, since the y axis is shared.
            na.plt.set_ylabel(
                f"field $y$ ({na.unit(position.y):latex_inline})",
                ax=ax[{axis_wavelength: 0}],
            )
            na.plt.set_title(wavelength.to_string_array(), ax=ax)

            plt.colorbar(
                mappable=matplotlib.cm.ScalarMappable(colorizer=colorizer),
                ax=ax.ndarray,
                label=f"residual ({unit:latex_inline})",
            )

        return fig, ax