# Source code for pyvisgen.simulation.scan

from __future__ import annotations

from math import pi
from typing import TYPE_CHECKING

import torch
from radioft.finufft import CupyFinufft
from scipy.constants import c
from torch.special import bessel_j1

if TYPE_CHECKING:
    from typing import Literal

    from numpy.typing import ArrayLike

# Make float64 the default floating-point dtype. This is a process-wide torch
# setting, so it also affects other code once this module has been imported.
torch.set_default_dtype(torch.float64)


# Module-level NUFFT operator used by ``apply_finufft``; it is constructed once,
# when this module is imported.
# NOTE(review): image_size=512 and fov_arcsec=1024 are fixed here — confirm they
# match the sky images passed to ``rime(..., ft="finufft")``. Building it at
# import time presumably needs CuPy even when the finufft backend is never
# used — TODO confirm.
finufft = CupyFinufft(image_size=512, fov_arcsec=1024, eps=1e-8)

# Public names of this module (what ``from ... import *`` exports).
__all__ = [
    "rime",
    "apply_finufft",
    "calc_fourier",
    "calc_feed_rotation",
    "calc_beam",
    "angular_distance",
    "jinc",
    "integrate",
]


# @torch.compile
def rime(
    img: ArrayLike,
    bas: ArrayLike,
    lm: ArrayLike,
    rd: ArrayLike,
    ra: ArrayLike,
    dec: ArrayLike,
    ant_diam: ArrayLike,
    spw_low: ArrayLike,
    spw_high: ArrayLike,
    polarization: str,
    mode: str,
    corrupted: bool = False,
    ft: Literal["default", "finufft", "reversed"] = "default",
):
    """Calculates visibilities using the RIME.

    Parameters
    ----------
    img : torch.tensor
        Sky distribution.
    bas : dataclass object
        Baselines dataclass. ``bas[2]``, ``bas[5]`` and ``bas[8]`` are used
        as the u, v and w coordinates by the Fourier step.
    lm : 2d array
        lm grid for the FOV.
    rd : torch.tensor
        Pixel coordinate grid, forwarded to :func:`calc_beam` (only used
        when ``corrupted`` is ``True``).
    ra : float
        Right ascension of the source, forwarded to :func:`calc_beam`.
    dec : float
        Declination of the source, forwarded to :func:`calc_beam`.
    ant_diam : torch.tensor
        Antenna diameters, forwarded to :func:`calc_beam`.
    spw_low : float
        Lower edge of the spectral window. It multiplies ``u / c`` in the
        Fourier kernel, i.e. it is a frequency (not a wavelength).
    spw_high : float
        Upper edge of the spectral window (a frequency, see ``spw_low``).
    polarization : str
        Type of polarization. If falsy, no feed rotation is applied.
    mode : str
        Select one of ``'full'``, ``'grid'``, or ``'dense'`` to get all valid
        baselines, a grid of unique baselines, or dense baselines. Here it
        only matters in that ``'dense'`` skips the feed rotation.
    corrupted : bool, optional
        If ``True``, apply the beam computed by :func:`calc_beam`.
        Default: ``False``
    ft : str, optional
        Sets the type of Fourier transform used in the RIME. Choose one of
        ``'default'``, ``'finufft'`` (Flatiron Institute Nonuniform Fast
        Fourier Transform) or ``'reversed'``. Default: ``'default'``

    Returns
    -------
    torch.tensor
        Visibility for every baseline.

    Raises
    ------
    ValueError
        If ``ft`` is not one of the supported values.
    """
    # Validate first: previously an unknown value fell through every branch
    # and failed with an UnboundLocalError on ``vis``.
    if ft not in ("default", "reversed", "finufft"):
        raise ValueError(
            f"Unknown ft={ft!r}; expected 'default', 'finufft' or 'reversed'."
        )

    # feed rotation is skipped without a polarization and in 'dense' mode
    apply_rotation = polarization and mode != "dense"

    with torch.no_grad():
        if ft == "default":
            X1 = img.clone()
            X2 = img.clone()
            X1, X2 = calc_fourier(X1, X2, bas, lm, spw_low, spw_high)
            if apply_rotation:
                X1, X2 = calc_feed_rotation(X1, X2, bas, polarization)
            if corrupted:
                X1, X2 = calc_beam(
                    X1, X2, rd, ra, dec, ant_diam, spw_low, spw_high
                )
            vis = integrate(X1, X2)
        elif ft == "reversed":
            # one copy of the sky per baseline, so rotation and beam are
            # applied before the Fourier kernel
            tiled = torch.repeat_interleave(img[None], len(bas[2]), dim=0)
            X1 = tiled.clone()
            X2 = tiled.clone()
            del tiled
            if apply_rotation:
                X1, X2 = calc_feed_rotation(X1, X2, bas, polarization)
            if corrupted:
                X1, X2 = calc_beam(
                    X1, X2, rd, ra, dec, ant_diam, spw_low, spw_high
                )
            X1, X2 = calc_fourier(X1, X2, bas, lm, spw_low, spw_high)
            vis = integrate(X1, X2)
        else:  # ft == "finufft"
            # NOTE(review): this branch never applies the feed rotation,
            # whatever ``polarization`` is — confirm this is intended.
            X1 = img.clone()
            X2 = img.clone()
            if corrupted:
                X1, X2 = calc_beam(
                    X1, X2, rd, ra, dec, ant_diam, spw_low, spw_high
                )
            vis = apply_finufft(X1, X2, bas, lm, spw_low, spw_high)
    return vis
def apply_finufft(
    X1: torch.tensor,
    X2: torch.tensor,
    bas,
    lm: torch.tensor,
    spw_low: float,
    spw_high: float,
) -> torch.tensor:  # pragma: no cover
    """Computes visibilities with the module-level ``finufft`` operator.

    Each of the four entries of the trailing 2x2 axes is transformed at both
    spectral-window edges, and the two results are averaged.

    Parameters
    ----------
    X1 : :func:`~torch.tensor`
        Sky tensor for the lower window edge, with trailing 2x2 axes
        (e.g. ``(n_pixels, 2, 2)``), the layout used by the other RIME steps
        in this module.
    X2 : :func:`~torch.tensor`
        Sky tensor for the upper window edge, same layout as ``X1``.
    bas : :class:`~pyvisgen.simulation.ValidBaselineSubset`
        Baselines; ``bas[2]``, ``bas[5]`` and ``bas[8]`` are used as u, v, w.
    lm : :func:`~torch.tensor`
        lm grid for the FOV; ``lm[..., 0]`` is l and ``lm[..., 1]`` is m.
    spw_low : float
        Lower frequency of the spectral window (multiplies ``u / c``).
    spw_high : float
        Upper frequency of the spectral window.

    Returns
    -------
    :func:`~torch.tensor`
        Visibilities of shape ``(n_baselines, 2, 2)``.

    Raises
    ------
    RuntimeError
        If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is not available. Finufft backend requires a CUDA-enabled GPU to run."
        )

    l_coords = lm[..., 0]
    m_coords = lm[..., 1]
    n_coords = torch.sqrt(1 - l_coords**2 - m_coords**2)

    # u, v, w scaled by frequency / c at each edge of the spectral window
    u_coords_low = bas[2] / c * spw_low
    v_coords_low = bas[5] / c * spw_low
    w_coords_low = bas[8] / c * spw_low
    u_coords_high = bas[2] / c * spw_high
    v_coords_high = bas[5] / c * spw_high
    w_coords_high = bas[8] / c * spw_high

    n_baselines = len(bas[2])

    # One contiguous row per entry of the trailing 2x2 axes: (..., 2, 2) ->
    # (4, n). A plain ``reshape(4, -1)`` would cut the flat buffer into four
    # consecutive chunks and mix pixels with correlations; this layout is the
    # inverse of the ``.T.reshape(n_baselines, 2, 2)`` applied to the output.
    X1_flat = X1.reshape(-1, 4).T.contiguous()
    X2_flat = X2.reshape(-1, 4).T.contiguous()

    # The inputs above were produced on the current stream. PyTorch side
    # streams do not synchronize with it implicitly, so each must wait first.
    main_stream = torch.cuda.current_stream()
    streams = [torch.cuda.Stream() for _ in range(4)]

    results_low = []
    results_high = []
    for i, stream in enumerate(streams):
        stream.wait_stream(main_stream)
        with torch.cuda.stream(stream):
            results_low.append(
                finufft.nufft(
                    X1_flat[i],
                    l_coords,
                    m_coords,
                    n_coords,
                    u_coords_low,
                    v_coords_low,
                    w_coords_low,
                )
            )
            results_high.append(
                finufft.nufft(
                    X2_flat[i],
                    l_coords,
                    m_coords,
                    n_coords,
                    u_coords_high,
                    v_coords_high,
                    w_coords_high,
                )
            )

    # wait for every stream before combining their results
    torch.cuda.synchronize()

    vis_low_all = torch.stack(results_low)
    vis_high_all = torch.stack(results_high)
    vis_avg = (vis_low_all + vis_high_all) / 2
    # (4, n_baselines) -> (n_baselines, 2, 2)
    return vis_avg.T.reshape(n_baselines, 2, 2)
def calc_fourier(
    X1: torch.tensor,
    X2: torch.tensor,
    bas,
    lm: torch.tensor,
    spw_low: float,
    spw_high: float,
) -> tuple[torch.tensor, torch.tensor]:
    """Calculates Fourier transformation kernel for every baseline and pixel
    in the lm grid, and applies it to the sky tensors.

    Parameters
    ----------
    X1 : :func:`~torch.tensor`
        Sky tensor for the lower window edge (trailing 2x2 axes).
    X2 : :func:`~torch.tensor`
        Sky tensor for the upper window edge (trailing 2x2 axes).
    bas : :class:`~pyvisgen.simulation.ValidBaselineSubset`
        :class:`~pyvisgen.simulation.Baselines` dataclass object containing
        information on u, v, and w coverage, and observation times.
        ``bas[2]``, ``bas[5]`` and ``bas[8]`` are read as u, v and w.
    lm : :func:`~torch.tensor`
        lm grid for FOV; ``lm[..., 0]`` is l and ``lm[..., 1]`` is m.
    spw_low : float
        Lower frequency of the spectral window. It multiplies ``u / c``, so
        it is a frequency (presumably Hz if u, v, w are in metres).
    spw_high : float
        Upper frequency of the spectral window.

    Returns
    -------
    tuple[torch.tensor, torch.tensor]
        ``X1`` and ``X2`` multiplied by their Fourier kernels. The kernel
        has the baseline axis first, then the lm axes, then two broadcast
        axes matching the trailing 2x2 axes of the sky tensors.
    """
    # only use u, v, w valid
    u_valid = bas[2]
    v_valid = bas[5]
    w_valid = bas[8]

    l = lm[..., 0]  # noqa: E741
    m = lm[..., 1]
    n = torch.sqrt(1 - l**2 - m**2)

    # geometric term u*l + v*m + w*(n - 1), one row per baseline
    delay = (
        u_valid[..., None] * l
        + v_valid[..., None] * m
        + w_valid[..., None] * (n - 1)
    )
    del l, m, n, u_valid, v_valid, w_valid

    # Frequency-independent part of the phase, computed once for both window
    # edges (previously the full baseline x pixel sum was built twice). The
    # operation order matches ``-2*pi*1j*delay / c * spw``, so values are
    # bit-identical.
    phase = -2 * pi * 1j * delay / c
    del delay

    K1 = torch.exp(phase * spw_low)[..., None, None]
    K2 = torch.exp(phase * spw_high)[..., None, None]
    del phase

    return X1 * K1, X2 * K2
def calc_feed_rotation(
    X1: torch.tensor,
    X2: torch.tensor,
    bas,
    polarization: str,
) -> tuple[torch.tensor, torch.tensor]:
    """Calculates the feed rotation due to the parallactic angle rotation
    of the source over time.

    Parameters
    ----------
    X1 : :func:`~torch.tensor`
        Fourier kernel calculated via
        :func:`~pyvisgen.simulation.calc_fourier`.
    X2 : :func:`~torch.tensor`
        Fourier kernel calculated via
        :func:`~pyvisgen.simulation.calc_fourier`.
    bas : :class:`~pyvisgen.simulation.ValidBaselineSubset`
        :class:`~pyvisgen.simulation.Baselines` dataclass object containing
        information on u, v, and w coverage, observation times, and
        parallactic angles. ``bas[13]`` is the rotation angle applied to
        ``X1`` and ``bas[16]`` the one applied to ``X2``.
    polarization : str
        Type of polarization for the feed. ``"circular"`` applies phase
        factors; any other value applies a linear (rotation-matrix) feed.

    Returns
    -------
    X1 : :func:`~torch.tensor`
        Fourier kernel with the applied feed rotation (detached from autograd).
    X2 : :func:`~torch.tensor`
        Fourier kernel with the applied feed rotation (detached from autograd).
    """
    # trailing axis so the angles broadcast over the pixel axis
    q1 = bas[13][..., None]
    q2 = bas[16][..., None]
    return (
        _rotate_feed(X1, q1, polarization),
        _rotate_feed(X2, q2, polarization),
    )


def _rotate_feed(
    X: torch.tensor, q: torch.tensor, polarization: str
) -> torch.tensor:
    """Applies the feed rotation for angle(s) ``q`` to one kernel tensor."""
    out = torch.zeros_like(X)
    if polarization == "circular":
        # Column 0 gets exp(+iq), column 1 gets exp(-iq). Each factor is
        # evaluated once instead of once per element.
        pos = torch.exp(1j * q)
        neg = torch.exp(-1j * q)
        out[..., 0, 0] = X[..., 0, 0] * pos
        out[..., 0, 1] = X[..., 0, 1] * neg
        out[..., 1, 0] = X[..., 1, 0] * pos
        out[..., 1, 1] = X[..., 1, 1] * neg
    else:
        # Linear feed: each row (x, y) becomes
        # (x cos q - y sin q, x sin q + y cos q).
        cos_q = torch.cos(q)
        sin_q = torch.sin(q)
        out[..., 0, 0] = X[..., 0, 0] * cos_q - X[..., 0, 1] * sin_q
        out[..., 0, 1] = X[..., 0, 0] * sin_q + X[..., 0, 1] * cos_q
        out[..., 1, 0] = X[..., 1, 0] * cos_q - X[..., 1, 1] * sin_q
        out[..., 1, 1] = X[..., 1, 0] * sin_q + X[..., 1, 1] * cos_q
    # ``out`` is a fresh tensor, so detaching is enough. The former extra
    # ``.clone()`` only duplicated it in memory.
    return out.detach()
def calc_beam(
    X1: torch.tensor,
    X2: torch.tensor,
    rd: torch.tensor,
    ra: float,
    dec: float,
    ant_diam: torch.tensor,
    spw_low: float,
    spw_high: float,
) -> tuple[torch.tensor, torch.tensor]:
    """Computes the beam influence on the image.

    A jinc beam is evaluated at both spectral-window edges and applied to
    the corresponding tensor on both sides (``E * X * E``).

    Parameters
    ----------
    X1 : :func:`~torch.tensor`
        Tensor for the lower window edge.
    X2 : :func:`~torch.tensor`
        Tensor for the upper window edge.
    rd : :func:`~torch.tensor`
        Pixel coordinate grid passed to :func:`angular_distance`; its device
        is also used for ``ant_diam``.
    ra : float
        Right ascension of the source, passed to :func:`angular_distance`.
        NOTE(review): the visible ``angular_distance`` does not use it.
    dec : float
        Declination of the source, passed to :func:`angular_distance`.
    ant_diam : :func:`~torch.tensor`
        Antenna diameters.
    spw_low : float
        Lower frequency of the spectral window.
    spw_high : float
        Upper frequency of the spectral window.

    Returns
    -------
    tuple[torch.tensor, torch.tensor]
        ``X1`` and ``X2`` weighted by their beams.
    """
    # pixel distance from the source, combined with every antenna diameter
    scaled = ant_diam.to(rd.device) * angular_distance(rd, ra, dec)[..., None]

    # NOTE(review): the jinc argument is 2*pi*f*D*theta/c. The common Airy
    # convention is pi*D*sin(theta)/lambda — confirm the intended scale.
    beam_low = jinc(2 * pi / c * spw_low * scaled)
    beam_high = jinc(2 * pi / c * spw_high * scaled)
    assert beam_low.shape == beam_high.shape

    # beam applied on both sides, with a trailing axis to broadcast over X
    weighted_low = beam_low[..., None] * X1 * beam_low[..., None]
    weighted_high = beam_high[..., None] * X2 * beam_high[..., None]
    return weighted_low, weighted_high
[docs] @torch.compile def angular_distance(rd, ra, dec): """Calculates angular distance from source position Parameters ---------- rd : 3d tensor every pixel containing ra and dec ra : float right ascension of source position dec : float declination of source position Returns ------- 2d array Returns angular Distance for every pixel in rd grid with respect to source position """ r = rd[..., 0] d = rd[..., 1] - torch.deg2rad(dec.to(rd.device)) theta = torch.arcsin(torch.sqrt(r**2 + d**2)) return theta
[docs] @torch.compile def jinc(x): """Create jinc function. Parameters ---------- x : array value of (?) Returns ------- array value of jinc function at x """ jinc = torch.ones(x.shape, device=x.device).double() jinc[x != 0] = 2 * bessel_j1(x[x != 0]) / x[x != 0] return jinc
[docs] @torch.compile def integrate(X1, X2): """Summation over (l,m) and avering over time and freq Parameters ---------- X1 : 3d tensor visibility for every (l,m) and baseline for freq1 X2 : 3d tensor visibility for every (l,m) and baseline for freq2 Returns ------- 2d tensor Returns visibility for every baseline """ X_f = torch.stack((X1, X2)) # sum over all sky pixels # only integrate for 1 sky dimension # 2d sky is reshaped to 1d by sensitivity mask int_lm = torch.sum(X_f, dim=2) del X_f # average two bandwidth edges int_f = 0.5 * torch.sum(int_lm, dim=0) del int_lm return int_f