Source code for pyvisgen.simulation.visibility

from __future__ import annotations

from dataclasses import dataclass, fields
from typing import TYPE_CHECKING

import scipy.ndimage
import torch
from tqdm.auto import tqdm

from pyvisgen.simulation.noise import generate_noise
from pyvisgen.simulation.scan import RIMEScan
from pyvisgen.utils.batch_size import adaptive_batch_size
from pyvisgen.utils.logging import setup_logger

if TYPE_CHECKING:
    from typing import Literal

# NOTE: process-wide side effect -- after this module is imported, every
# floating-point tensor created without an explicit dtype (in any module)
# defaults to float64.
torch.set_default_dtype(torch.float64)
LOGGER = setup_logger(namespace=__name__)

# Public API of this module. ``generate_noise`` is re-exported from
# ``pyvisgen.simulation.noise`` (imported above).
__all__ = [
    "Visibilities",
    "vis_loop",
    "Polarization",
    "generate_noise",
]


@dataclass
class Visibilities:
    """Visibilities dataclass.

    Container for simulated visibilities and their metadata. Indexing
    (``vis[i]``) and :meth:`add` act on the first axis of *every* field,
    including ``linear_dop`` and ``circular_dop``.

    Attributes
    ----------
    V_11 : :func:`~torch.tensor`
        Element ``[0, 0]`` of the 2 x 2 visibility matrix.
    V_22 : :func:`~torch.tensor`
        Element ``[1, 1]`` of the 2 x 2 visibility matrix.
    V_12 : :func:`~torch.tensor`
        Element ``[0, 1]`` of the 2 x 2 visibility matrix.
    V_21 : :func:`~torch.tensor`
        Element ``[1, 0]`` of the 2 x 2 visibility matrix.
    weights : :func:`~torch.tensor`
        Visibility weights.
    num : :func:`~torch.tensor`
        Running visibility number.
    base_num : :func:`~torch.tensor`
        Baseline number of each visibility.
    u : :func:`~torch.tensor`
        u coordinate of each visibility.
    v : :func:`~torch.tensor`
        v coordinate of each visibility.
    w : :func:`~torch.tensor`
        w coordinate of each visibility.
    date : :func:`~torch.tensor`
        Observation date of each visibility.
    st_id_pairs : :func:`~torch.tensor`
        Station id pairs of each visibility (two ids per row).
    linear_dop : :func:`~torch.tensor`
        Degree of linear polarization.
    circular_dop : :func:`~torch.tensor`
        Degree of circular polarization.
    """

    V_11: torch.Tensor
    V_22: torch.Tensor
    V_12: torch.Tensor
    V_21: torch.Tensor
    weights: torch.Tensor
    num: torch.Tensor
    base_num: torch.Tensor
    u: torch.Tensor
    v: torch.Tensor
    w: torch.Tensor
    date: torch.Tensor
    st_id_pairs: torch.Tensor
    linear_dop: torch.Tensor
    circular_dop: torch.Tensor

    def __getitem__(self, i):
        """Return a new ``Visibilities`` with every field indexed by ``i``."""
        return Visibilities(*[getattr(self, f.name)[i] for f in fields(self)])

    def get_values(self):
        """Stack the four visibility components along a new last axis.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(*V_11.shape, 4)`` whose last axis is ordered
            ``V_11, V_22, V_12, V_21``.
        """
        return torch.stack([self.V_11, self.V_22, self.V_12, self.V_21], dim=-1)

    def add(self, visibilities):
        """Append ``visibilities`` to this object in place.

        Every field is replaced by the concatenation (along the first axis)
        of its current value and the matching field of ``visibilities``.
        """
        for f in fields(self):
            setattr(
                self,
                f.name,
                torch.cat([getattr(self, f.name), getattr(visibilities, f.name)]),
            )
class Polarization:
    r"""Simulation of polarization.

    Creates the :math:`2\times 2` stokes matrix and simulates polarization
    if ``polarization`` is either ``'linear'`` or ``'circular'``. Also
    computes the degree of polarization.

    Parameters
    ----------
    SI : :func:`~torch.tensor`
        Stokes I component, i.e. intensity distribution of the sky. The
        tensor is permuted with ``dims=(1, 2, 0)`` and only index ``0`` of
        the (new) last axis is used (presumably ``SI`` has shape
        ``(channels, height, width)`` -- confirm against callers).
    sensitivity_cut : float
        Sensitivity cut; only pixels with ``SI >= sensitivity_cut`` are
        kept by :meth:`stokes_matrix`.
    amp_ratio : float or None
        Sets the ratio of :math:`A_{X|R}`. The ratio of :math:`A_{Y|L}` is
        calculated as ``1 - amp_ratio``. If set to ``None`` (or a negative
        value), a random value is drawn from a uniform distribution. See
        also: ``random_state``.
    delta : float or :func:`~torch.tensor`
        Phase difference, in degrees, of the amplitudes :math:`A_{X|R}`
        and :math:`A_{Y|L}` of the sky distribution. Defines the measure
        of ellipticity.
    polarization : str or None
        Choose between ``'linear'`` or ``'circular'`` or ``None`` to
        simulate different types of polarizations or disable the
        simulation of polarization entirely.
    field_kwargs : dict
        Keyword arguments forwarded to :meth:`rand_polarization_field`
        (``order``, ``random_state``, ``scale``, ``threshold``).
    random_state : int or None
        Seed passed to ``torch.manual_seed`` before the random polarization
        field and ``amp_ratio`` are drawn. ``None`` leaves the global RNG
        untouched.
    device : :class:`~torch.device`
        Torch device to select for computation.
    """

    def __init__(
        self,
        SI: torch.Tensor,
        sensitivity_cut: float,
        amp_ratio: float | None,
        delta: float | torch.Tensor,
        polarization: str | None,
        field_kwargs: dict,
        random_state: int | None,
        device: torch.device,
    ) -> None:
        """Prepare the amplitudes and the Stokes buffer (see class docstring)."""
        self.sensitivity_cut = sensitivity_cut
        self.polarization = polarization
        self.device = device
        self.SI = SI.permute(dims=(1, 2, 0))

        # ``is not None`` so that a seed of 0 is honoured too.
        if random_state is not None:
            torch.manual_seed(random_state)

        if self.polarization in ("circular", "linear"):
            # Drawn before ``amp_ratio`` so the RNG draw order is unchanged.
            self.polarization_field = self.rand_polarization_field(
                [self.SI.shape[0], self.SI.shape[1]],
                **field_kwargs,
            )

            if isinstance(delta, (float, int)):
                delta = torch.tensor(delta)
            self.delta = delta

            # ``is not None``: amp_ratio=0 (no X/R amplitude) is a valid
            # choice and must not fall through to a random draw.
            if amp_ratio is not None and amp_ratio >= 0:
                ax2 = amp_ratio
            else:
                ax2 = torch.rand(1)
            if isinstance(ax2, torch.Tensor):
                ax2 = ax2.to(self.device)
            ay2 = 1 - ax2

            self.ax2 = self.SI[..., 0].detach().clone().to(self.device) * ax2
            self.ay2 = self.SI[..., 0].detach().clone().to(self.device) * ay2
        else:
            self.ax2 = self.SI[..., 0]
            self.ay2 = torch.zeros_like(self.ax2)

        # Stokes buffer (I, Q, U, V) on the last axis. NOTE: allocated on the
        # default device, not ``self.device``.
        self.I = torch.zeros(  # noqa: E741
            (self.SI.shape[0], self.SI.shape[1], 4), dtype=torch.cdouble
        )

    def _cross_terms(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``2 A_1 A_2 cos(delta)`` and ``-2 A_1 A_2 sin(delta)``."""
        amplitude = 2 * torch.sqrt(self.ax2) * torch.sqrt(self.ay2)
        phase = torch.deg2rad(self.delta)
        return amplitude * torch.cos(phase), -amplitude * torch.sin(phase)

    def linear(self) -> None:
        r"""Computes the stokes parameters I, Q, U, and V for linear
        polarization. This is done using the following equations:

        .. math::

            I &= A_X^2 + A_Y^2 \\
            Q &= A_X^2 - A_Y^2 \\
            U &= 2A_X A_Y \cos\delta_{XY} \\
            V &= -2A_X A_Y \sin\delta_{XY}
        """
        cos_term, sin_term = self._cross_terms()
        self.I[..., 0] = self.ax2 + self.ay2
        self.I[..., 1] = self.ax2 - self.ay2
        self.I[..., 2] = cos_term
        self.I[..., 3] = sin_term

    def circular(self) -> None:
        r"""Computes the stokes parameters I, Q, U, and V for circular
        polarization. This is done using the following equations:

        .. math::

            I &= A_R^2 + A_L^2 \\
            Q &= 2A_R A_L \cos\delta_{RL} \\
            U &= -2A_R A_L \sin\delta_{RL} \\
            V &= A_R^2 - A_L^2
        """
        cos_term, sin_term = self._cross_terms()
        self.I[..., 0] = self.ax2 + self.ay2
        self.I[..., 1] = cos_term
        self.I[..., 2] = sin_term
        self.I[..., 3] = self.ax2 - self.ay2

    def dop(self) -> None:
        """Computes the degree of polarization for each pixel.

        Multiplies Q, U and V by ``self.polarization_field`` in place and
        stores ``self.lin_dop`` and ``self.circ_dop``. Pixels without flux
        (``ax2 + ay2 <= 0``) are set to NaN.
        """
        # ``ax2``/``ay2`` may live on ``self.device`` while ``self.I`` does
        # not; align the mask with the tensor it indexes.
        mask = ((self.ax2 + self.ay2) > 0).to(self.I.device)

        # apply polarization_field to Q, U, and V only
        self.I[..., 1:] *= self.polarization_field[..., None]

        stokes = self.I.real.detach().clone()
        stokes[~mask] = float("nan")
        dop_I, dop_Q, dop_U, dop_V = stokes.unbind(dim=-1)

        self.lin_dop = torch.sqrt(dop_Q**2 + dop_U**2) / dop_I
        self.circ_dop = torch.abs(dop_V) / dop_I

    def stokes_matrix(self) -> tuple:
        """Computes and returns the 2 x 2 stokes matrix B.

        Returns
        -------
        B : torch.tensor
            2 x 2 stokes brightness matrix for every kept pixel. Either for
            linear, circular or no polarization.
        mask : torch.tensor
            Mask of the sensitivity cut (keeps all px >= sensitivity_cut).
        lin_dop : torch.tensor
            Degree of linear polarization of every pixel in the sky.
        circ_dop : torch.tensor
            Degree of circular polarization of every pixel in the sky.
        """
        B = torch.zeros(
            (self.SI.shape[0], self.SI.shape[1], 2, 2), dtype=torch.cdouble
        ).to(torch.device(self.device))

        if self.polarization == "linear":
            self.linear()
        elif self.polarization == "circular":
            self.circular()
        else:
            # No polarization applied: only Stokes I is non-zero.
            self.I[..., 0] = self.SI[..., 0]
            self.polarization_field = torch.ones_like(self.I[..., 0])
        self.dop()

        # Indices into (I, Q, U, V):
        #   linear / none:  diag = I +- Q,  off-diag = U +- iV
        #   circular:       diag = I +- V,  off-diag = Q +- iU
        if self.polarization == "circular":
            diag, off_re, off_im = 3, 1, 2
        else:
            diag, off_re, off_im = 1, 2, 3

        stokes = self.I
        B[..., 0, 0] = stokes[..., 0] + stokes[..., diag]
        B[..., 0, 1] = stokes[..., off_re] + 1j * stokes[..., off_im]
        B[..., 1, 0] = stokes[..., off_re] - 1j * stokes[..., off_im]
        B[..., 1, 1] = stokes[..., 0] - stokes[..., diag]

        # calculations only for px >= sensitivity cut
        mask = (self.sensitivity_cut <= self.SI)[..., 0]
        B = B[mask]

        return B, mask, self.lin_dop, self.circ_dop

    def rand_polarization_field(
        self,
        shape: list[int] | int,
        order: list[int] | int = 1,
        random_state: int | None = None,
        scale: list | None = None,
        threshold: float | None = None,
    ) -> torch.Tensor:
        """Generates a random noise mask for polarization.

        Parameters
        ----------
        shape : array_like (M, N), or int
            The size of the sky image. A single value is used for both axes.
        order : array_like (M, N) or int, optional
            Morphology of the random noise. Higher values create more and
            smaller fluctuations (the Gaussian filter width is
            ``mean(shape) / (40 * order)``). Default: ``1``.
        random_state : int, optional
            Seed for ``torch.manual_seed``. If ``None``, the global torch
            RNG is used as-is (no reseeding). Default: ``None``.
        scale : array_like, optional
            Range ``[low, high]`` of the output values. Default: the minimum
            and maximum of the filtered noise.
        threshold : float, optional
            If not ``None``, values above ``threshold`` are clipped to it.
            Default: ``None``

        Returns
        -------
        im : torch.tensor
            Array of the requested shape with values between ``scale[0]``
            and ``scale[1]`` (capped at ``threshold`` if given).

        Raises
        ------
        ValueError
            If ``shape``, ``order`` or ``scale`` has more than two entries
            (``scale``: not exactly two).
        """
        if random_state is not None:
            torch.manual_seed(random_state)

        # Always work on copies so caller-supplied lists are never mutated.
        shape = [shape] if isinstance(shape, int) else list(shape)
        if len(shape) < 2:
            shape = shape * 2
        elif len(shape) > 2:
            raise ValueError("Expected len of 'shape' to be 2!")

        order = [order] if isinstance(order, (int, float)) else list(order)
        if len(order) < 2:
            order = order * 2
        elif len(order) > 2:
            raise ValueError("Expected len of 'order' to be 2!")

        sigma = torch.mean(torch.tensor(shape).double()) / (40 * torch.tensor(order))
        im = torch.rand(shape)
        im = scipy.ndimage.gaussian_filter(im, sigma=sigma.numpy())

        if scale is None:
            scale = [im.min(), im.max()]
        if len(scale) != 2:
            raise ValueError("Expected len of 'scale' to be 2!")

        # Rank-transform the smoothed noise onto an evenly spaced grid so the
        # value distribution is uniform while the spatial structure is kept.
        im_flatten = torch.from_numpy(im.flatten())
        ranks = torch.argsort(torch.argsort(im_flatten))
        im_linspace = torch.linspace(
            float(scale[0]), float(scale[1]), ranks.size()[0]
        )
        im = torch.reshape(im_linspace[ranks], im.shape)

        if threshold is not None:
            # Upper threshold that keeps the image shape intact (boolean
            # indexing would flatten it and break broadcasting in dop()).
            im = torch.clamp(im, max=threshold)

        return im
def vis_loop(
    obs,
    SI: torch.Tensor,
    num_threads: int = 10,
    noise_level: float = 0,
    noise_mode: str = "sefd",
    telescope: str = "meerkat",
    band: str | None = None,
    mode: str = "full",
    batch_size: int | Literal["auto"] = "auto",
    show_progress: bool = False,
    normalize: bool = True,
    ft: Literal["default", "finufft", "reversed"] = "default",
) -> Visibilities:
    r"""Computes the visibilities of an observation.

    Parameters
    ----------
    obs : Observation class object
        Observation class object generated by the
        `~pyvisgen.simulation.Observation` class. Polarization settings are
        read from it: ``obs.polarization``, ``obs.field_kwargs`` and
        ``obs.pol_kwargs`` (the latter presumably holding ``amp_ratio``,
        ``delta`` and ``random_state`` -- see :class:`Polarization`).
    SI : torch.tensor
        Tensor containing the sky intensity distribution.
    num_threads : int, optional
        Number of threads used for intraoperative parallelism on the CPU.
        See `~torch.set_num_threads`. Default: 10
    noise_level : float, optional
        Noise amplitude: SEFD in Jy when ``noise_mode='sefd'``, or
        T_sys/η in K when ``noise_mode='tsys'``. Set to 0 to disable noise.
        Default: 0
    noise_mode : str, optional
        ``'sefd'``: uniform SEFD noise (backward compatible, no elevation
        dependence). ``'tsys'``: elevation-dependent noise from system
        temperature. Default: ``'sefd'``
    telescope : str, optional
        Telescope name for elevation-dependent Tsys corrections. Only used
        when ``noise_mode='tsys'``. Default: ``'meerkat'``
    band : str or None, optional
        Observing band forwarded to the noise generator. Only used when
        ``noise_level != 0``. Default: ``None``
    mode : str, optional
        Select one of ``'full'``, ``'grid'``, or ``'dense'`` to get all valid
        baselines, a grid of unique baselines, or dense baselines.
        Default: ``'full'``
    batch_size : int or ``'auto'``, optional
        Batch size for iteration over baselines. ``'auto'`` starts with all
        baselines in a single batch; ``adaptive_batch_size`` may then adjust
        it. Default: ``'auto'``
    show_progress : bool, optional
        If ``True``, show a progress bar during the iteration over the
        batches of baselines. Default: ``False``
    normalize : bool, optional
        If ``True``, normalize stokes matrix ``B`` by a factor 0.5.
        Default: ``True``
    ft : str, optional
        Sets the type of fourier transform used in the RIME. Choose one of
        ``'default'``, ``'finufft'`` (Flatiron Institute Nonuniform Fast
        Fourier Transform) or ``'reversed'``. Default: ``'default'``

    Returns
    -------
    visibilities : Visibilities
        Dataclass object containing visibilities and baselines.

    Raises
    ------
    ValueError
        If ``batch_size`` is neither an int nor ``'auto'``, if ``mode`` is
        unsupported, or if ``mode='dense'`` is requested on the CPU.
    """
    torch.set_num_threads(num_threads)
    torch._dynamo.config.suppress_errors = True

    if not (
        isinstance(batch_size, int)
        or (isinstance(batch_size, str) and batch_size == "auto")
    ):
        raise ValueError("Expected batch_size to be 'auto' or type int")

    pol = Polarization(
        SI,
        sensitivity_cut=obs.sensitivity_cut,
        polarization=obs.polarization,
        device=obs.device,
        field_kwargs=obs.field_kwargs,
        **obs.pol_kwargs,
    )
    B, mask, lin_dop, circ_dop = pol.stokes_matrix()

    # Restrict the sky grids to the pixels that survived the sensitivity cut.
    lm = obs.lm[mask]
    rd = obs.rd[mask]

    # normalize visibilities to factor 0.5,
    # so that the Stokes I image is normalized to 1
    if normalize:
        B *= 0.5

    # Empty accumulator; _batch_loop appends one chunk per batch.
    n_waves = len(obs.waves_low)
    visibilities = Visibilities(
        V_11=torch.empty(size=[0, n_waves]),
        V_22=torch.empty(size=[0, n_waves]),
        V_12=torch.empty(size=[0, n_waves]),
        V_21=torch.empty(size=[0, n_waves]),
        weights=torch.tensor([]),
        num=torch.tensor([]),
        base_num=torch.tensor([]),
        u=torch.tensor([]),
        v=torch.tensor([]),
        w=torch.tensor([]),
        date=torch.tensor([]),
        st_id_pairs=torch.empty(0, 2),
        linear_dop=torch.tensor([]),
        circular_dop=torch.tensor([]),
    )

    vis_num = torch.zeros(1)

    if mode == "full":
        bas = obs.baselines.get_valid_subset(obs.num_baselines, obs.device)
    elif mode == "grid":
        bas = obs.baselines.get_valid_subset(
            obs.num_baselines, obs.device
        ).get_unique_grid(obs.fov, obs.ref_frequency, obs.img_size, obs.device)
    elif mode == "dense":
        if obs.device == torch.device("cpu"):
            raise ValueError("Mode 'dense' is only available for GPU calculations!")

        # We cannot test this with our CI at the moment
        obs.calc_dense_baselines()  # pragma: no cover
        bas = obs.dense_baselines_gpu  # pragma: nocover
    else:
        raise ValueError(f"Unsupported mode: {mode}")

    if batch_size == "auto":
        batch_size = bas.baseline_nums.shape[0]

    visibilities = adaptive_batch_size(
        _batch_loop,
        batch_size,
        visibilities=visibilities,
        vis_num=vis_num,
        obs=obs,
        B=B,
        bas=bas,
        lm=lm,
        rd=rd,
        noise_level=noise_level,
        noise_mode=noise_mode,
        telescope=telescope,
        band=band,
        show_progress=show_progress,
        mode=mode,
        ft=ft,
    )

    # Degree-of-polarization maps are per sky pixel, not per visibility,
    # so they are attached once after the batch loop.
    visibilities.linear_dop = lin_dop.cpu()
    visibilities.circular_dop = circ_dop.cpu()

    return visibilities
def _keep_rows(values, keep: torch.Tensor):
    """Apply the boolean row mask ``keep`` to ``values`` when row-aligned.

    Non-tensors, 0-dim tensors and tensors whose first axis length differs
    from ``keep`` are returned unchanged.
    """
    if (
        isinstance(values, torch.Tensor)
        and values.ndim > 0
        and values.shape[0] == keep.shape[0]
    ):
        return values[keep.to(values.device)]
    return values


def _batch_loop(
    batch_size: int,
    visibilities,
    vis_num: torch.Tensor,
    obs,
    B: torch.Tensor,
    bas,
    lm: torch.Tensor,
    rd: torch.Tensor,
    noise_level: float,
    noise_mode: str,
    telescope: str,
    band: str | None,
    show_progress: bool,
    mode: str,
    ft: Literal["default", "finufft", "reversed"] = "default",
):
    """Main simulation loop of pyvisgen. Computes visibilities batchwise.

    Parameters
    ----------
    batch_size : int
        Batch size for loop over Baselines dataclass object.
    visibilities : Visibilities
        Visibilities dataclass object. The results of all batches are
        appended to it in place (in a single step after the loop) and it is
        returned.
    vis_num : torch.Tensor
        Visibility counter; numbering of the new visibilities starts at
        ``vis_num.max() + 1``.
    obs : Observation
        Observation class object.
    B : torch.tensor
        Stokes matrix containing stokes visibilities.
    bas : Baselines
        Baselines dataclass object.
    lm : torch.tensor
        lm grid.
    rd : torch.tensor
        rd grid.
    noise_level : float
        Noise amplitude forwarded to ``generate_noise``. If ``0``, no noise
        is added and all weights are 1.
    noise_mode : str
        Noise mode forwarded to ``generate_noise`` (``'sefd'`` or
        ``'tsys'``).
    telescope : str
        Telescope name forwarded to ``generate_noise``.
    band : str or None
        Observing band forwarded to ``generate_noise``.
    show_progress : bool
        If True, show a progress bar tracking the loop.
    mode : str
        Select one of `'full'`, `'grid'`, or `'dense'` to get all valid
        baselines, a grid of unique baselines, or dense baselines.
    ft : str, optional
        Sets the type of fourier transform used in the RIME. Choose one of
        ``'default'``, ``'finufft'`` (Flatiron Institute Nonuniform Fast
        Fourier Transform) or `'reversed'`. Default: ``'default'``

    Returns
    -------
    visibilities : Visibilities
        Visibilities dataclass object.
    """
    batches = torch.arange(bas.baseline_nums.shape[0]).split(batch_size)
    batches = tqdm(
        batches,
        position=0,
        disable=not show_progress,
        desc="Computing visibilities",
        postfix=f"Batch size: {batch_size}",
    )

    rime = RIMEScan(ft=ft, mode=mode, obs=obs, lm=lm, rd=rd)

    # Collect per-batch results and concatenate once after the loop:
    # calling ``visibilities.add`` per batch re-copies everything gathered so
    # far (quadratic in the number of batches) and would leave
    # ``visibilities`` partially filled if a batch raised midway.
    batch_results = []

    for p in batches:
        bas_p = bas[p]

        # One RIME evaluation per spectral window, stacked on a new axis 0.
        int_values = torch.cat(
            tensors=[
                rime(B, bas_p, spw_low=wave_low, spw_high=wave_high)[None]
                for wave_low, wave_high in zip(obs.waves_low, obs.waves_high)
            ]
        )
        if int_values.numel() == 0:
            continue

        # Move the spectral-window axis behind the baseline axis so that
        # rows line up with the entries of ``bas_p``.
        int_values = torch.swapaxes(int_values, 0, 1)

        # Drop every row that contains a NaN anywhere.
        keep = ~torch.isnan(int_values).any(dim=(1, 2, 3))
        int_values = int_values[keep]

        if noise_level != 0:
            # Elevations are filtered with the same row mask as the
            # visibilities so both stay aligned.
            noise, weights = generate_noise(
                int_values.shape,
                obs,
                noise_level,
                mode=noise_mode,
                el1_deg=_keep_rows(bas_p.el1_valid, keep),
                el2_deg=_keep_rows(bas_p.el2_valid, keep),
                telescope=telescope,
                band=band,
            )
            int_values += noise
        else:
            weights = torch.ones(int_values.shape[0], int_values.shape[1])

        vis_num = torch.arange(int_values.shape[0]) + 1 + vis_num.max()

        batch_results.append(
            Visibilities(
                V_11=int_values[..., 0, 0].cpu(),
                V_22=int_values[..., 1, 1].cpu(),
                V_12=int_values[..., 0, 1].cpu(),
                V_21=int_values[..., 1, 0].cpu(),
                weights=weights.cpu(),
                num=vis_num,
                base_num=bas_p.baseline_nums[keep].cpu(),
                u=bas_p.u_valid[keep].cpu(),
                v=bas_p.v_valid[keep].cpu(),
                w=bas_p.w_valid[keep].cpu(),
                date=bas_p.date[keep].cpu(),
                st_id_pairs=bas_p.st_id_pairs[keep].cpu(),
                linear_dop=torch.tensor([]),
                circular_dop=torch.tensor([]),
            )
        )

        del int_values

    if batch_results:
        merged = Visibilities(
            *[
                torch.cat([getattr(chunk, f.name) for chunk in batch_results])
                for f in fields(Visibilities)
            ]
        )
        visibilities.add(merged)

    return visibilities