# Source code for pyvisgen.simulation.scan

from math import pi

import torch
from scipy.constants import c
from torch.special import bessel_j1

# Process-wide side effect: from here on the default floating-point dtype of
# torch (for every module, not only this one) is float64.
torch.set_default_dtype(torch.float64)

# The cuFINUFFT backend (radioft) is optional. If it cannot be imported the
# error text is kept so that RIMEScan.finufft can re-raise it when the
# finufft transform is actually requested.
try:
    from radioft.finufft import CupyFinufft

    _FINUFFT_AVAIL = True

except ImportError as e:
    _FINUFFT_AVAIL = False
    # Only defined on the failure path; read only when _FINUFFT_AVAIL is False.
    _FINUFFT_ERROR = str(e)


# Public API of this module.
__all__ = [
    "RIMEScan",
    "apply_finufft",
    "calc_fourier",
    "calc_feed_rotation",
    "calc_beam",
    "angular_distance",
    "jinc",
    "integrate",
]


class RIMEScan:
    """Apply the Radio Interferometry Measurement Equation (RIME) to sky images.

    This class handles the calculation of visibilities from sky images using
    either direct Fourier transforms or Non-Uniform Fast Fourier Transforms
    (FINUFFT). Depending on the observation settings it also applies
    telescope beam effects or feed rotation.

    Parameters
    ----------
    ft : str
        Fourier transform method to use ('default', 'reversed', or
        'finufft'). The value is resolved with ``getattr`` to the method of
        the same name.
    mode : str
        Observation mode, can be 'full', 'grid' or 'dense'. Dense is only
        suitable for debugging and disables feed rotation calculation.
    obs : :class:`~pyvisgen.simulation.Observation`
        Observation class object containing observation parameters such as
        the source position. ``img_size``, ``fov``, ``polarization``,
        ``corrupted``, ``ra``, ``dec`` and ``array.diam`` are read.
    lm : :class:`~torch.Tensor`
        Direction cosines (l, m) grid for the field of view.
    rd : :class:`~torch.Tensor`
        Right ascension and declination grid corresponding to the field of
        view.
    eps : float, optional
        Tolerance for the cuFINUFFT. Default: 1e-8.
    """

    def __init__(self, ft, mode, obs, lm, rd, eps=1e-8):
        # Keep the cuFINUFFT settings so the transform object can also be
        # created lazily in ``finufft``. It is only built eagerly when the
        # finufft transform was selected: constructing it unconditionally may
        # set up the CuPy/CUDA backend even for the pure-torch transforms.
        self._finufft_kwargs = {
            "image_size": obs.img_size,
            "fov_arcsec": obs.fov,
            "eps": eps,
        }
        self.cupy_finufft = (
            CupyFinufft(**self._finufft_kwargs)
            if _FINUFFT_AVAIL and ft == "finufft"
            else None
        )

        self.mode = mode
        self.ft = ft
        self.ft_func = getattr(self, ft)

        self.polarization = obs.polarization
        self.corrupted = obs.corrupted
        self.ra = obs.ra
        self.dec = obs.dec
        self.ant_diam = torch.unique(obs.array.diam)
        self.lm = lm
        self.rd = rd

    def __call__(
        self,
        img: torch.Tensor,
        bas,
        spw_low: torch.Tensor,
        spw_high: torch.Tensor,
    ) -> torch.Tensor:
        """Process the input sky image to produce visibilities.

        Runs without autograd (inside ``torch.no_grad()``).

        Parameters
        ----------
        img : :class:`~torch.Tensor`
            Input sky image tensor. It is never modified; the transforms
            receive private copies.
        bas : :class:`~pyvisgen.simulation.ValidBaselineSubset`
            Subset of valid baselines containing uvw coordinates.
        spw_low : :class:`~torch.Tensor`
            Lower band-edge frequency (used as ``u / c * spw_low``).
        spw_high : :class:`~torch.Tensor`
            Higher band-edge frequency (used as ``u / c * spw_high``).

        Returns
        -------
        :class:`~torch.Tensor`
            Calculated complex visibilities for the given baselines.
        """
        with torch.no_grad():
            if self.ft == "reversed":
                # The reversed chain needs one copy of the sky per baseline.
                # A broadcast view provides that without allocating; the two
                # clones below materialise it as independent contiguous
                # tensors (previously an extra full-size copy was made first).
                img = img[None].expand(len(bas.u_valid), *img.shape)

            # Separate copies per band edge so no stage can alias or modify
            # the caller's image or the other band.
            X1 = img.clone()
            X2 = img.clone()

            return self.ft_func(
                X1,
                X2,
                bas,
                spw_low,
                spw_high,
            )

    def default(
        self,
        X1: torch.Tensor,
        X2: torch.Tensor,
        bas: torch.Tensor,
        spw_low: torch.Tensor,
        spw_high: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate the default RIME Jones chain.

        Calculates direct Fourier kernels, applies feed rotation (unless
        ``polarization`` is falsy or ``mode == "dense"``) and beam effects
        (if ``corrupted``), and integrates over the image.

        Parameters
        ----------
        X1 : :class:`~torch.Tensor`
            Sky tensor for the lower band edge.
        X2 : :class:`~torch.Tensor`
            Sky tensor for the higher band edge.
        bas : :class:`~pyvisgen.simulation.ValidBaselineSubset`
            Subset of valid baselines containing uvw coordinates.
        spw_low : :class:`~torch.Tensor`
            Lower band-edge frequency.
        spw_high : :class:`~torch.Tensor`
            Higher band-edge frequency.

        Returns
        -------
        :class:`~torch.Tensor`
            Visibilities calculated with the default RIME Jones chain.
        """
        X1, X2 = calc_fourier(X1, X2, bas, self.lm, spw_low, spw_high)

        if self.polarization and self.mode != "dense":
            X1, X2 = calc_feed_rotation(X1, X2, bas, self.polarization)

        if self.corrupted:
            X1, X2 = calc_beam(
                X1,
                X2,
                self.rd,
                self.ra,
                self.dec,
                self.ant_diam,
                spw_low,
                spw_high,
            )

        vis = integrate(X1, X2)
        return vis

    def reversed(
        self,
        X1: torch.Tensor,
        X2: torch.Tensor,
        bas: torch.Tensor,
        spw_low: torch.Tensor,
        spw_high: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the default RIME Jones chain in reversed order.

        Applies feed rotation and primary beam effects before calculating
        Fourier kernels and integrating. ``__call__`` replicates the sky once
        per baseline for this method.

        Parameters
        ----------
        X1 : :class:`~torch.Tensor`
            Sky tensor for the lower band edge.
        X2 : :class:`~torch.Tensor`
            Sky tensor for the higher band edge.
        bas : :class:`~pyvisgen.simulation.ValidBaselineSubset`
            Subset of valid baselines containing uvw coordinates.
        spw_low : :class:`~torch.Tensor`
            Lower band-edge frequency.
        spw_high : :class:`~torch.Tensor`
            Higher band-edge frequency.

        Returns
        -------
        :class:`~torch.Tensor`
            Visibilities calculated with the reversed RIME Jones chain.
        """
        if self.polarization and self.mode != "dense":
            X1, X2 = calc_feed_rotation(X1, X2, bas, self.polarization)

        if self.corrupted:
            X1, X2 = calc_beam(
                X1,
                X2,
                self.rd,
                self.ra,
                self.dec,
                self.ant_diam,
                spw_low,
                spw_high,
            )

        X1, X2 = calc_fourier(X1, X2, bas, self.lm, spw_low, spw_high)

        vis = integrate(X1, X2)
        return vis

    def finufft(
        self,
        X1: torch.Tensor,
        X2: torch.Tensor,
        bas: torch.Tensor,
        spw_low: torch.Tensor,
        spw_high: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate RIME using cuFINUFFT.

        Utilizes GPU-accelerated non-uniform FFTs for faster visibility
        computation. Beam effects are applied if ``corrupted``; feed rotation
        is NOT applied on this path.

        Parameters
        ----------
        X1 : :class:`~torch.Tensor`
            Sky tensor for the lower band edge.
        X2 : :class:`~torch.Tensor`
            Sky tensor for the higher band edge.
        bas : :class:`~pyvisgen.simulation.ValidBaselineSubset`
            Subset of valid baselines containing uvw coordinates.
        spw_low : :class:`~torch.Tensor`
            Lower band-edge frequency.
        spw_high : :class:`~torch.Tensor`
            Higher band-edge frequency.

        Returns
        -------
        :class:`~torch.Tensor`
            Visibilities computed with cuFINUFFT.

        Raises
        ------
        RuntimeError
            If cuFINUFFT package is not successfully loaded or available.
        """
        if not _FINUFFT_AVAIL:
            raise RuntimeError(_FINUFFT_ERROR)

        if self.cupy_finufft is None:
            # Scanner was configured for another transform; build on demand.
            self.cupy_finufft = CupyFinufft(**self._finufft_kwargs)

        if self.corrupted:
            X1, X2 = calc_beam(
                X1,
                X2,
                self.rd,
                self.ra,
                self.dec,
                self.ant_diam,
                spw_low,
                spw_high,
            )

        vis = apply_finufft(
            X1, X2, bas, self.lm, spw_low, spw_high, finufft=self.cupy_finufft
        )
        return vis
[docs] def apply_finufft( X1: torch.Tensor, X2: torch.Tensor, bas, lm: torch.Tensor, spw_low: float | torch.Tensor, spw_high: float | torch.Tensor, finufft, ) -> torch.Tensor: # pragma: no cover """Apply cuFINUFFT to input images to compute visibilities. Parameters ---------- X1 : :class:`~torch.Tensor` Sky tensor for the lower spectral window. X2 : :class:`~torch.Tensor` Sky tensor for the higher spectral window. bas : :class:`~pyvisgen.simulation.ValidBaselineSubset` Subset of valid baselines containing uvw coordinates. spw_low : :class:`~torch.Tensor` Lower spectral window frequencies/wavelengths. spw_high : :class:`~torch.Tensor` Higher spectral window frequencies/wavelengths. finufft : :class:`~radioft.finufft.finufft.CupyFinufft` Initialized :class:`~radioft.finufft.finufft.CupyFinufft` object to be used to compute the visibilities. Returns ------- :class:`~torch.Tensor` Visibilities computed with cuFINUFFT. Raises ------ RuntimeError If CUDA is not available. """ if not torch.cuda.is_available(): raise RuntimeError( "CUDA is not available. Finufft backend requires a CUDA-enabled GPU to run." 
) l_coords = lm[..., 0] m_coords = lm[..., 1] n_coords = torch.sqrt(1 - l_coords**2 - m_coords**2) u_coords_low = bas.u_valid / c * spw_low v_coords_low = bas.v_valid / c * spw_low w_coords_low = bas.w_valid / c * spw_low u_coords_high = bas.u_valid / c * spw_high v_coords_high = bas.v_valid / c * spw_high w_coords_high = bas.w_valid / c * spw_high # Reshape input X1_flat = X1.permute(1, 2, 0).reshape(4, -1) X2_flat = X2.permute(1, 2, 0).reshape(4, -1) results_low = [] results_high = [] for i in range(4): vis_low = finufft.nufft( X1_flat[i], l_coords, m_coords, n_coords, u_coords_low, v_coords_low, w_coords_low, ) vis_high = finufft.nufft( X2_flat[i], l_coords, m_coords, n_coords, u_coords_high, v_coords_high, w_coords_high, ) results_low.append(vis_low) results_high.append(vis_high) # Stack and reshape vis_low_all = torch.stack(results_low) vis_high_all = torch.stack(results_high) vis_avg = (vis_low_all + vis_high_all) / 2 vis = vis_avg.mT.reshape(-1, 2, 2) return vis
[docs] def calc_fourier( X1: torch.Tensor, X2: torch.Tensor, bas, lm: torch.Tensor, spw_low: float | torch.Tensor, spw_high: float | torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """Calculates Fourier transformation kernel for every baseline and pixel in the lm grid. Parameters ---------- X1 : :func:`~torch.tensor` Sky tensor. X2 : :func:`~torch.tensor` Sky tensor. bas : :class:`~pyvisgen.simulation.ValidBaselineSubset` :class:`~pyvisgen.simulation.Baselines` dataclass object containing information on u, v, and w coverage, and observation times. lm : :func:`~torch.tensor` lm grid for FOV. spw_low : float Lower wavelength. spw_high : float Higher wavelength. Returns ------- tuple[torch.tensor, torch.tensor] Fourier kernels for every pixel in the lm grid and given baselines. Shape is given by lm axes and baseline axis. """ # only use u, v, w valid u_valid = bas.u_valid v_valid = bas.v_valid w_valid = bas.w_valid l = lm[..., 0] # noqa: E741 m = lm[..., 1] n = torch.sqrt(1 - l**2 - m**2) ul = u_valid[..., None] * l vm = v_valid[..., None] * m wn = w_valid[..., None] * (n - 1) del l, m, n, u_valid, v_valid, w_valid K1 = torch.exp(-2 * pi * 1j * (ul + vm + wn) / c * spw_low)[..., None, None] K2 = torch.exp(-2 * pi * 1j * (ul + vm + wn) / c * spw_high)[..., None, None] del ul, vm, wn return X1 * K1, X2 * K2
def calc_feed_rotation(
    X1: torch.Tensor,
    X2: torch.Tensor,
    bas,
    polarization: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Calculate the feed rotation due to the parallactic angle rotation of
    the source over time.

    Every trailing 2x2 matrix of the inputs is multiplied from the right by
    the feed matrix of its baseline: ``diag(exp(iq), exp(-iq))`` for
    ``polarization == "circular"``, otherwise the rotation
    ``[[cos q, sin q], [-sin q, cos q]]``.

    Parameters
    ----------
    X1 : :class:`~torch.Tensor`
        Tensor whose last two axes hold the 2x2 correlation matrix (the
        Fourier-kernel product from :func:`~pyvisgen.simulation.calc_fourier`
        or, in the reversed chain, the replicated sky). Rotated with
        ``bas.q1_valid``.
    X2 : :class:`~torch.Tensor`
        Same layout as ``X1``. Rotated with ``bas.q2_valid``.
    bas : :class:`~pyvisgen.simulation.ValidBaselineSubset`
        :class:`~pyvisgen.simulation.Baselines` dataclass object containing
        information on u, v, and w coverage, observation times, and
        parallactic angles (``q1_valid``, ``q2_valid``).
    polarization : str
        Type of polarization for the feed; ``"circular"`` selects the
        circular feed matrix, any other value the linear rotation.

    Returns
    -------
    X1 : :class:`~torch.Tensor`
        ``X1`` with the applied feed rotation, detached from autograd. For
        circular feeds the result is complex even if ``X1`` is real.
    X2 : :class:`~torch.Tensor`
        ``X2`` with the applied feed rotation, detached from autograd.
    """
    q1 = bas.q1_valid[..., None]
    q2 = bas.q2_valid[..., None]

    # NOTE(review): X1/X2 are the two band edges elsewhere in this module,
    # yet X1 is rotated with q1 and X2 with q2. If q1/q2 are the parallactic
    # angles of the two antennas of a baseline, this pairing looks
    # suspicious — confirm against the Baselines definition.
    return (
        _rotate_feed(X1, q1, polarization),
        _rotate_feed(X2, q2, polarization),
    )


def _rotate_feed(X: torch.Tensor, q: torch.Tensor, polarization: str) -> torch.Tensor:
    """Right-multiply every trailing 2x2 matrix of ``X`` by the feed matrix for ``q``."""
    if polarization == "circular":
        # diag(exp(iq), exp(-iq)): column 0 gets exp(+iq), column 1 exp(-iq).
        col0 = torch.exp(1j * q)
        col1 = torch.exp(-1j * q)
        # Promote the output dtype: a real X must not silently lose the
        # imaginary part of the phase factors.
        out = torch.empty_like(X, dtype=torch.promote_types(X.dtype, col0.dtype))
        out[..., 0, 0] = X[..., 0, 0] * col0
        out[..., 0, 1] = X[..., 0, 1] * col1
        out[..., 1, 0] = X[..., 1, 0] * col0
        out[..., 1, 1] = X[..., 1, 1] * col1
    else:
        # Rotation [[cos q, sin q], [-sin q, cos q]] applied from the right.
        cos_q = torch.cos(q)
        sin_q = torch.sin(q)
        out = torch.empty_like(X, dtype=torch.promote_types(X.dtype, cos_q.dtype))
        out[..., 0, 0] = X[..., 0, 0] * cos_q - X[..., 0, 1] * sin_q
        out[..., 0, 1] = X[..., 0, 0] * sin_q + X[..., 0, 1] * cos_q
        out[..., 1, 0] = X[..., 1, 0] * cos_q - X[..., 1, 1] * sin_q
        out[..., 1, 1] = X[..., 1, 0] * sin_q + X[..., 1, 1] * cos_q
    # ``out`` is freshly allocated, so detach() (no copy) suffices to return
    # a tensor cut off from autograd; the former extra clone was redundant.
    return out.detach()
def calc_beam(
    X1: torch.Tensor,
    X2: torch.Tensor,
    rd: torch.Tensor,
    ra: float | torch.Tensor,
    dec: float | torch.Tensor,
    ant_diam: torch.Tensor,
    spw_low: float | torch.Tensor,
    spw_high: float | torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply the primary-beam attenuation to both band-edge tensors.

    The beam factor of every pixel is ``E = jinc(2 * pi / c * f * D * theta)``,
    where ``theta`` is the pixel's angular distance from the source, ``D`` the
    antenna diameter and ``f`` the band-edge frequency. It is applied on both
    sides of the sky term (``E * X * E``); since ``E`` is a real per-pixel
    factor multiplied element-wise, this equals ``E**2 * X``.

    Parameters
    ----------
    X1 : :class:`~torch.Tensor`
        Tensor for the lower band edge; its last two axes are the 2x2
        correlation matrix.
    X2 : :class:`~torch.Tensor`
        Tensor for the higher band edge, same layout as ``X1``.
    rd : :class:`~torch.Tensor`
        RA/Dec grid of the field of view (forwarded to
        :func:`angular_distance`).
    ra : float or :class:`~torch.Tensor`
        Right ascension of the source (forwarded to :func:`angular_distance`).
    dec : float or :class:`~torch.Tensor`
        Declination of the source in degrees (forwarded to
        :func:`angular_distance`).
    ant_diam : :class:`~torch.Tensor`
        Antenna diameter(s); moved to the device of ``rd``.
    spw_low : float or :class:`~torch.Tensor`
        Lower band-edge frequency.
    spw_high : float or :class:`~torch.Tensor`
        Higher band-edge frequency.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        ``X1`` and ``X2`` attenuated by the beam at the respective band edge.
    """
    diameters = ant_diam.to(rd.device)
    theta = angular_distance(rd, ra, dec)
    # NOTE(review): the diameter axis is appended after the pixel axes and
    # broadcasts against the 2x2 correlation axes below; this presumably
    # assumes a single unique diameter — confirm.
    tds = diameters * theta[..., None]

    E1 = jinc(2 * pi / c * spw_low * tds)
    E2 = jinc(2 * pi / c * spw_high * tds)

    # Square the small beam tensor once instead of multiplying the large
    # X tensor twice (saves one full-size intermediate per band edge).
    EXE1 = (E1**2)[..., None] * X1
    EXE2 = (E2**2)[..., None] * X2
    return EXE1, EXE2
[docs] @torch.compile def angular_distance(rd: torch.Tensor, ra: torch.Tensor, dec: torch.Tensor): """Calculates angular distance from source position Parameters ---------- rd : 3d tensor every pixel containing ra and dec ra : float right ascension of source position dec : float declination of source position Returns ------- 2d array Returns angular Distance for every pixel in rd grid with respect to source position """ r = rd[..., 0] d = rd[..., 1] - torch.deg2rad(dec.to(rd.device)) theta = torch.arcsin(torch.sqrt(r**2 + d**2)) return theta
[docs] @torch.compile def jinc(x): """Create jinc function. Parameters ---------- x : array value of (?) Returns ------- array value of jinc function at x """ jinc = torch.ones(x.shape, device=x.device).double() jinc[x != 0] = 2 * bessel_j1(x[x != 0]) / x[x != 0] return jinc
@torch.compile
def integrate(X1, X2):
    """Sum over the sky pixels and average the two band edges.

    Parameters
    ----------
    X1 : :class:`~torch.Tensor`
        Per-pixel visibility contributions at the lower band edge. Axis 1 is
        the sky-pixel axis that gets summed (presumably shaped
        ``(baselines, pixels, 2, 2)``).
    X2 : :class:`~torch.Tensor`
        Same as ``X1`` for the higher band edge; expected to have the same
        shape.

    Returns
    -------
    :class:`~torch.Tensor`
        ``0.5 * (X1.sum(1) + X2.sum(1))``: one value per baseline and
        correlation product.
    """
    # Reduce each band edge on its own instead of stacking X1 and X2 first:
    # the stacked copy would duplicate both (potentially very large) inputs.
    # Only one sky dimension is summed because the 2d sky is reshaped to 1d
    # by the sensitivity mask.
    int_low = torch.sum(X1, dim=1)
    int_high = torch.sum(X2, dim=1)

    # average the two bandwidth edges
    return 0.5 * (int_low + int_high)