"""pyvisgen.simulation.scan: RIME-based visibility simulation kernels."""

from math import pi

import torch
from scipy.constants import c
from torch.special import bessel_j1

# NOTE: module-level side effect. Importing this module switches the
# process-wide torch default floating-point dtype to float64.
torch.set_default_dtype(torch.float64)

# Public API of this module.
__all__ = [
    "rime",
    "calc_fourier",
    "calc_feed_rotation",
    "calc_beam",
    "angular_distance",
    "jinc",
    "integrate",
]


@torch.compile
def rime(
    img,
    bas,
    lm,
    rd,
    ra,
    dec,
    ant_diam,
    spw_low,
    spw_high,
    polarization,
    mode,
    corrupted=False,
):
    """Calculates visibilities using the RIME.

    The whole computation runs under ``torch.no_grad()``.

    Parameters
    ----------
    img : torch.tensor
        Sky distribution.
    bas : dataclass object
        Baselines dataclass, forwarded to ``calc_fourier`` and
        ``calc_feed_rotation``.
    lm : torch.tensor
        lm grid for the FOV.
    rd : torch.tensor
        Per-pixel coordinate grid; only used when ``corrupted`` is true
        (forwarded to ``calc_beam``).
    ra : float
        Right ascension of the source; only forwarded to ``calc_beam``.
    dec : float
        Declination of the source; only forwarded to ``calc_beam``.
    ant_diam : torch.tensor
        Antenna diameters; only forwarded to ``calc_beam``.
    spw_low : float
        Lower spectral-window edge.
    spw_high : float
        Upper spectral-window edge.
    polarization : str
        Feed polarization type. A falsy value skips the feed rotation;
        ``"circular"`` selects circular feeds, anything else the linear
        rotation in ``calc_feed_rotation``.
    mode : str
        Simulation mode. The feed rotation is skipped when this equals
        ``"dense"``.
    corrupted : bool, optional
        If true, apply the primary-beam attenuation via ``calc_beam``.
        Default is ``False``.

    Returns
    -------
    torch.tensor
        Visibility for every baseline, as returned by ``integrate``.
    """
    with torch.no_grad():
        X1, X2 = calc_fourier(img, bas, lm, spw_low, spw_high)

        rotate_feeds = polarization and mode != "dense"
        if rotate_feeds:
            X1, X2 = calc_feed_rotation(X1, X2, bas, polarization)

        if corrupted:
            X1, X2 = calc_beam(
                X1, X2, rd, ra, dec, ant_diam, spw_low, spw_high
            )

        return integrate(X1, X2)
@torch.compile
def calc_fourier(
    img: torch.tensor,
    bas,
    lm: torch.tensor,
    spw_low: float,
    spw_high: float,
) -> tuple[torch.tensor, torch.tensor]:
    """Calculates the Fourier transformation kernel for every baseline
    and pixel in the lm grid and applies it to the sky distribution.

    Parameters
    ----------
    img : :func:`~torch.tensor`
        Sky distribution. It is multiplied with the kernel, which carries
        two trailing singleton axes, so ``img`` broadcasts against
        ``(..., pixels, 1, 1)``.
    bas : :class:`~pyvisgen.simulation.ValidBaselineSubset`
        :class:`~pyvisgen.simulation.Baselines` dataclass object
        containing information on u, v, and w coverage, and observation
        times. Only ``bas[2]``, ``bas[5]`` and ``bas[8]`` are read; they
        are used as the valid u, v and w coordinates.
    lm : :func:`~torch.tensor`
        lm grid for the FOV; ``lm[..., 0]`` is l and ``lm[..., 1]`` is m.
    spw_low : float
        Lower spectral-window edge. It enters the phase as
        ``spw_low / c``, so it is presumably a frequency in Hz (with
        u, v, w in metres). TODO confirm.
    spw_high : float
        Upper spectral-window edge, same convention as ``spw_low``.

    Returns
    -------
    tuple[torch.tensor, torch.tensor]
        ``img * K`` for the lower and the upper edge, where
        ``K = exp(-2 pi i (u l + v m + w (n - 1)) * spw / c)`` and
        ``n = sqrt(1 - l**2 - m**2)``. For a 1-d grid of P pixels and
        B baselines, ``K`` has shape ``(B, P, 1, 1)``.
    """
    l = lm[..., 0]  # noqa: E741
    m = lm[..., 1]
    # Pixels with l**2 + m**2 > 1 lie outside the unit sphere -> NaN.
    n = torch.sqrt(1 - l**2 - m**2)

    # Geometric term (u l + v m + w (n - 1)) for every baseline/pixel
    # pair. Together with -2 pi i / c it is frequency independent, so it
    # is computed once and shared by both spectral-window edges.
    delay = (
        bas[2][..., None] * l
        + bas[5][..., None] * m
        + bas[8][..., None] * (n - 1)
    )
    phase = -2 * pi * 1j * delay / c
    del l, m, n, delay

    K1 = torch.exp(phase * spw_low)[..., None, None]
    K2 = torch.exp(phase * spw_high)[..., None, None]
    del phase

    return img * K1, img * K2
@torch.compile
def calc_feed_rotation(
    X1: torch.tensor,
    X2: torch.tensor,
    bas,
    polarization: str,
) -> tuple[torch.tensor, torch.tensor]:
    """Calculates the feed rotation due to the parallactic angle
    rotation of the source over time.

    Each 2x2 matrix in the last two axes of a kernel is multiplied from
    the right by a rotation matrix:

    * ``"circular"``: ``diag(exp(i q), exp(-i q))``
    * anything else (linear feeds): ``[[cos q, sin q], [-sin q, cos q]]``

    Parameters
    ----------
    X1 : :func:`~torch.tensor`
        Fourier kernel calculated via
        :func:`~pyvisgen.simulation.calc_fourier`; rotated with
        ``bas[13]``.
    X2 : :func:`~torch.tensor`
        Fourier kernel calculated via
        :func:`~pyvisgen.simulation.calc_fourier`; rotated with
        ``bas[16]``.
    bas : :class:`~pyvisgen.simulation.ValidBaselineSubset`
        :class:`~pyvisgen.simulation.Baselines` dataclass object. Only
        ``bas[13]`` and ``bas[16]`` are read; they are used as the angles
        ``q`` (presumably parallactic angles in radians; TODO confirm).
    polarization : str
        Type of polarization for the feed (``"circular"`` or linear).

    Returns
    -------
    X1 : :func:`~torch.tensor`
        Detached copy of the first kernel with the feed rotation applied.
    X2 : :func:`~torch.tensor`
        Detached copy of the second kernel with the feed rotation applied.
    """
    circular = polarization == "circular"

    def _rotate(kernel, angle):
        # Fresh output buffer with the kernel's dtype/device; every entry
        # of the trailing 2x2 block is overwritten below.
        rotated = torch.zeros_like(kernel)
        if circular:
            ph_pos = torch.exp(1j * angle)
            ph_neg = torch.exp(-1j * angle)
            for row in (0, 1):
                rotated[..., row, 0] = kernel[..., row, 0] * ph_pos
                rotated[..., row, 1] = kernel[..., row, 1] * ph_neg
        else:
            cos_q = torch.cos(angle)
            sin_q = torch.sin(angle)
            for row in (0, 1):
                left = kernel[..., row, 0]
                right = kernel[..., row, 1]
                rotated[..., row, 0] = left * cos_q - right * sin_q
                rotated[..., row, 1] = left * sin_q + right * cos_q
        return rotated.detach().clone()

    # Trailing None lets the per-baseline angle broadcast over pixels.
    return _rotate(X1, bas[13][..., None]), _rotate(X2, bas[16][..., None])
@torch.compile
def calc_beam(
    X1: torch.tensor,
    X2: torch.tensor,
    rd: torch.tensor,
    ra: float,
    dec: float,
    ant_diam: torch.tensor,
    spw_low: float,
    spw_high: float,
) -> tuple[torch.tensor, torch.tensor]:
    """Computes the beam influence on the image.

    For each spectral-window edge a beam ``E = jinc(2 pi / c * spw * D *
    theta)`` is evaluated, where ``theta`` is the angular distance of every
    pixel from the source (via ``angular_distance``) and ``D`` is
    ``ant_diam``. The kernel is then multiplied by ``E`` on both sides:
    ``E * X * E``. Presumably this models the same beam for both antennas
    of a baseline (identical dishes); TODO confirm.

    NOTE(review): the Airy voltage pattern of a dish of diameter D is
    usually ``jinc(pi * D * sin(theta) / lambda)``. The factor ``2 pi``
    used here would correspond to ``D`` being a radius. Confirm the
    convention for ``ant_diam``.

    Parameters
    ----------
    X1 : :func:`~torch.tensor`
        Kernel for the lower spectral-window edge.
    X2 : :func:`~torch.tensor`
        Kernel for the upper spectral-window edge.
    rd : :func:`~torch.tensor`
        Per-pixel coordinate grid passed to ``angular_distance``; also
        sets the device ``ant_diam`` is moved to.
    ra : float
        Right ascension of the source, forwarded to ``angular_distance``.
    dec : float
        Declination of the source, forwarded to ``angular_distance``.
    ant_diam : :func:`~torch.tensor`
        Antenna diameters, multiplied by ``theta[..., None]``.
    spw_low : float
        Lower spectral-window edge (enters as ``spw_low / c``).
    spw_high : float
        Upper spectral-window edge (enters as ``spw_high / c``).

    Returns
    -------
    tuple[torch.tensor, torch.tensor]
        ``E1 * X1 * E1`` and ``E2 * X2 * E2``, with the beam broadcast over
        one extra trailing axis.
    """
    dish = ant_diam.to(rd.device)
    theta = angular_distance(rd, ra, dec)
    dish_theta = dish * theta[..., None]

    scale = 2 * pi / c
    beam_low = jinc(scale * spw_low * dish_theta)
    beam_high = jinc(scale * spw_high * dish_theta)
    # Can only fail when spw_low / spw_high are tensors of different shapes.
    assert beam_low.shape == beam_high.shape

    b_low = beam_low[..., None]
    corrupted_low = b_low * X1 * b_low
    b_high = beam_high[..., None]
    corrupted_high = b_high * X2 * b_high

    return corrupted_low, corrupted_high
@torch.compile
def angular_distance(rd, ra, dec):
    """Calculates the angular distance of every pixel from the source.

    Parameters
    ----------
    rd : :func:`~torch.tensor`
        Grid whose last axis holds two coordinates per pixel.
        ``rd[..., 1]`` is compared against ``deg2rad(dec)``, so it is
        presumably a declination in radians. ``rd[..., 0]`` is used as-is,
        presumably already an RA offset relative to the source (TODO
        confirm).
    ra : float
        Right ascension of the source position. Currently unused; it is
        kept for interface compatibility.
    dec : float or :func:`~torch.tensor`
        Declination of the source position in degrees. Python numbers and
        tensors are both accepted.

    Returns
    -------
    :func:`~torch.tensor`
        ``arcsin(sqrt(r**2 + d**2))`` for every pixel, where
        ``r = rd[..., 0]`` and ``d = rd[..., 1] - deg2rad(dec)``. Pixels
        with ``r**2 + d**2 > 1`` yield NaN.
    """
    # torch.as_tensor accepts plain floats as well as tensors. For a tensor
    # it is equivalent to ``dec.to(rd.device)`` (dtype preserved).
    dec_rad = torch.deg2rad(torch.as_tensor(dec, device=rd.device))
    r = rd[..., 0]
    d = rd[..., 1] - dec_rad
    return torch.arcsin(torch.sqrt(r**2 + d**2))
@torch.compile
def jinc(x):
    """Evaluates the jinc function ``2 * J1(x) / x`` elementwise.

    Parameters
    ----------
    x : :func:`~torch.tensor`
        Argument(s) of the jinc function.

    Returns
    -------
    :func:`~torch.tensor`
        float64 tensor with the shape and device of ``x`` holding
        ``2 * J1(x) / x``. Entries where ``x == 0`` are set to the limit
        value 1.
    """
    nonzero = x != 0
    # Evaluate on the whole tensor and select afterwards. This avoids
    # data-dependent boolean-mask indexing, which creates dynamic shapes
    # (graph breaks) under torch.compile, and computes the mask only once.
    # The 0/0 produced where x == 0 is discarded by torch.where.
    values = 2 * bessel_j1(x) / x
    # Values are computed in x's dtype and returned as float64, as before.
    return torch.where(nonzero, values, torch.ones_like(values)).double()
@torch.compile
def integrate(X1, X2):
    """Sums over the sky pixels and averages over the two band edges.

    Only one sky axis is integrated: the 2d sky is presumably reshaped to
    1d by a sensitivity mask upstream, so axis 1 of each input is the
    pixel axis. No averaging over time is performed here.

    Parameters
    ----------
    X1 : :func:`~torch.tensor`
        Visibility for every (l, m) pixel and baseline at the lower
        spectral-window edge; axis 1 is summed.
    X2 : :func:`~torch.tensor`
        Same as ``X1`` for the upper spectral-window edge.

    Returns
    -------
    :func:`~torch.tensor`
        ``0.5 * (X1.sum(1) + X2.sum(1))``: visibility for every baseline.
    """
    # Reduce each edge separately instead of stacking first. torch.stack
    # would allocate a copy of both full-size kernels only to sum it away.
    int_low = torch.sum(X1, dim=1)
    int_high = torch.sum(X2, dim=1)
    return 0.5 * (int_low + int_high)