from __future__ import annotations
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING
import scipy.ndimage
import torch
from tqdm.auto import tqdm
from pyvisgen.simulation.noise import generate_noise
from pyvisgen.simulation.scan import RIMEScan
from pyvisgen.utils.batch_size import adaptive_batch_size
from pyvisgen.utils.logging import setup_logger
if TYPE_CHECKING:
    # Only needed for annotations. ``from __future__ import annotations`` at
    # the top of the module keeps annotations unevaluated at runtime, so the
    # name does not have to exist when the module is executed.
    from typing import Literal

# NOTE: process-wide side effect of importing this module. Floating point
# tensors created without an explicit dtype are double precision afterwards.
torch.set_default_dtype(torch.float64)

LOGGER = setup_logger(namespace=__name__)

# Public names of this module. ``generate_noise`` is imported from
# ``pyvisgen.simulation.noise`` and re-exported here.
__all__ = [
    "Visibilities",
    "vis_loop",
    "Polarization",
    "generate_noise",
]
@dataclass
class Visibilities:
    """Visibilities dataclass.

    :meth:`__getitem__` and :meth:`add` treat all fields uniformly:
    the index / concatenation is applied along the first axis of
    every field.

    Attributes
    ----------
    V_11 : :func:`~torch.tensor`
    V_22 : :func:`~torch.tensor`
    V_12 : :func:`~torch.tensor`
    V_21 : :func:`~torch.tensor`
    weights : :func:`~torch.tensor`
    num : :func:`~torch.tensor`
    base_num : :func:`~torch.tensor`
    u : :func:`~torch.tensor`
    v : :func:`~torch.tensor`
    w : :func:`~torch.tensor`
    date : :func:`~torch.tensor`
    st_id_pairs : :func:`~torch.tensor`
    linear_dop : :func:`~torch.tensor`
    circular_dop : :func:`~torch.tensor`
    """

    V_11: torch.Tensor
    V_22: torch.Tensor
    V_12: torch.Tensor
    V_21: torch.Tensor
    weights: torch.Tensor
    num: torch.Tensor
    base_num: torch.Tensor
    u: torch.Tensor
    v: torch.Tensor
    w: torch.Tensor
    date: torch.Tensor
    st_id_pairs: torch.Tensor
    linear_dop: torch.Tensor
    circular_dop: torch.Tensor

    def __getitem__(self, i):
        """Return a new ``Visibilities`` with every field indexed by ``i``."""
        return Visibilities(
            **{f.name: getattr(self, f.name)[i] for f in fields(self)}
        )

    def get_values(self):
        """Combine the four correlations into a single tensor.

        Returns
        -------
        :func:`~torch.tensor`
            ``V_11``, ``V_22``, ``V_12`` and ``V_21`` (in this order)
            along the last axis. For 2d inputs of shape ``(N, C)``
            the result has shape ``(N, C, 4)``.
        """
        stacked = torch.stack([self.V_11, self.V_22, self.V_12, self.V_21])
        # move the correlation axis from the front to the back
        return stacked.permute(1, 2, 0)

    def add(self, visibilities):
        """Append the fields of ``visibilities`` to this object in place.

        Parameters
        ----------
        visibilities : Visibilities
            Object whose fields are concatenated (along the first
            axis) behind the fields of this instance.
        """
        for f in fields(self):
            merged = torch.cat(
                [getattr(self, f.name), getattr(visibilities, f.name)]
            )
            setattr(self, f.name, merged)
class Polarization:
    r"""Simulation of polarization.

    Creates the :math:`2\times 2` stokes matrix and simulates
    polarization if ``polarization`` is either ``'linear'``
    or ``'circular'``. Also computes the degree of polarization.

    Parameters
    ----------
    SI : :func:`~torch.tensor`
        Stokes I component, i.e. intensity distribution
        of the sky. The leading axis is moved to the end
        internally and only its first entry is used.
    sensitivity_cut : float
        Sensitivity cut, where only pixels with a value greater
        than or equal to the cut are kept.
    amp_ratio : float
        Sets the ratio of :math:`A_{X|R}`. The ratio of :math:`A_{Y|L}`
        is calculated as ``1 - amp_ratio``. If set to ``None``
        (or a negative value), a random value is drawn from a
        uniform distribution. ``0`` is a valid ratio.
        See also: ``random_state``.
    delta : float or :func:`~torch.tensor`
        Sets the phase difference (in degrees) of the amplitudes
        :math:`A_{X|R}` and :math:`A_{Y|L}` of the sky distribution.
        Defines the measure of ellipticity.
    polarization : str
        Choose between ``'linear'`` or ``'circular'`` or ``None`` to
        simulate different types of polarizations or disable
        the simulation of polarization entirely.
    field_kwargs : dict
        Keyword arguments forwarded to :meth:`rand_polarization_field`.
        Only used if polarization is simulated.
    random_state : int
        Random state used when drawing ``amp_ratio`` and during
        the generation of the random polarization field. If ``None``,
        the torch random number generator is not re-seeded.
    device : :class:`~torch.cuda.device`
        Torch device to select for computation.
    """

    def __init__(
        self,
        SI: torch.Tensor,
        sensitivity_cut: float,
        amp_ratio: float,
        delta: float | torch.Tensor,
        polarization: str,
        field_kwargs: dict,
        random_state: int,
        device: torch.device,
    ) -> None:
        """Set up amplitudes and the polarization field.

        See the class docstring for a description of the parameters.
        """
        self.sensitivity_cut = sensitivity_cut
        self.polarization = polarization
        self.device = device

        # move the leading axis of SI to the end
        self.SI = SI.permute(dims=(1, 2, 0))

        # explicit None check so that a seed of 0 is honoured, too
        if random_state is not None:
            torch.manual_seed(random_state)

        if self.polarization in ("circular", "linear"):
            self.polarization_field = self.rand_polarization_field(
                [self.SI.shape[0], self.SI.shape[1]],
                **(field_kwargs or {}),
            )

            if isinstance(delta, (float, int)):
                delta = torch.tensor(delta)
            self.delta = delta

            # explicit None check: amp_ratio=0 is a valid ratio and must
            # not silently be replaced by a random draw
            if amp_ratio is not None and amp_ratio >= 0:
                ax2 = amp_ratio
            else:
                ax2 = torch.rand(1)

            if isinstance(ax2, torch.Tensor):
                ax2 = ax2.to(self.device)
            ay2 = 1 - ax2

            stokes_i = self.SI[..., 0].detach().clone().to(self.device)
            self.ax2 = stokes_i * ax2
            self.ay2 = stokes_i * ay2
        else:
            self.ax2 = self.SI[..., 0]
            self.ay2 = torch.zeros_like(self.ax2)

        # stokes parameters I, Q, U, V along the last axis
        self.I = torch.zeros(  # noqa: E741
            (self.SI.shape[0], self.SI.shape[1], 4), dtype=torch.cdouble
        )

    def _cross_terms(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(2 A_1 A_2 cos(delta), -2 A_1 A_2 sin(delta))``."""
        amplitude = 2 * torch.sqrt(self.ax2) * torch.sqrt(self.ay2)
        phase = torch.deg2rad(self.delta)
        return amplitude * torch.cos(phase), -amplitude * torch.sin(phase)

    def linear(self) -> None:
        r"""Computes the stokes parameters I, Q, U, and V
        for linear polarization.

        This is done using the following equations:

        .. math::

            I &= A_X^2 + A_Y^2 \\
            Q &= A_X^2 - A_Y^2 \\
            U &= 2A_X A_Y \cos\delta_{XY} \\
            V &= -2A_X A_Y \sin\delta_{XY}
        """
        cos_term, sin_term = self._cross_terms()

        self.I[..., 0] = self.ax2 + self.ay2
        self.I[..., 1] = self.ax2 - self.ay2
        self.I[..., 2] = cos_term
        self.I[..., 3] = sin_term

    def circular(self) -> None:
        r"""Computes the stokes parameters I, Q, U, and V
        for circular polarization.

        This is done using the following equations:

        .. math::

            I &= A_R^2 + A_L^2 \\
            Q &= 2A_R A_L \cos\delta_{RL} \\
            U &= -2A_R A_L \sin\delta_{RL} \\
            V &= A_R^2 - A_L^2
        """
        cos_term, sin_term = self._cross_terms()

        self.I[..., 0] = self.ax2 + self.ay2
        self.I[..., 1] = cos_term
        self.I[..., 2] = sin_term
        self.I[..., 3] = self.ax2 - self.ay2

    def dop(self) -> None:
        """Computes the degree of polarization for each pixel.

        Scales Q, U and V in place with the polarization field and
        stores the results in ``self.lin_dop`` and ``self.circ_dop``.
        Pixels without intensity are set to NaN.
        """
        mask = (self.ax2 + self.ay2) > 0

        # apply polarization_field to Q, U, and V only
        for component in (1, 2, 3):
            self.I[..., component] *= self.polarization_field

        # work on a real-valued copy so that self.I stays untouched
        stokes = self.I.real.detach().clone()
        stokes[~mask] = float("nan")

        self.lin_dop = (
            torch.sqrt(stokes[..., 1] ** 2 + stokes[..., 2] ** 2) / stokes[..., 0]
        )
        self.circ_dop = torch.abs(stokes[..., 3]) / stokes[..., 0]

    def stokes_matrix(self) -> tuple:
        """Computes and returns the 2 x 2 stokes matrix B.

        Returns
        -------
        B : torch.tensor
            2 x 2 stokes brightness matrix. Either for linear,
            circular or no polarization. Only the pixels selected
            by ``mask`` are returned.
        mask : torch.tensor
            Mask of the sensitivity cut (Keep all px >= sensitivity_cut).
        lin_dop : torch.tensor
            Degree of linear polarization of every pixel in the sky.
        circ_dop : torch.tensor
            Degree of circular polarization of every pixel in the sky.
        """
        # define 2 x 2 Stokes matrix
        B = torch.zeros(
            (self.SI.shape[0], self.SI.shape[1], 2, 2),
            dtype=torch.cdouble,
            device=self.device,
        )

        if self.polarization == "linear":
            self.linear()
        elif self.polarization == "circular":
            self.circular()
        else:
            # No polarization applied
            self.I[..., 0] = self.SI[..., 0]
            self.polarization_field = torch.ones_like(self.I[..., 0])
        self.dop()

        if self.polarization == "circular":
            # diagonal: I +- V, off-diagonal: Q +- iU
            diag, off_re, off_im = self.I[..., 3], self.I[..., 1], self.I[..., 2]
        else:
            # diagonal: I +- Q, off-diagonal: U +- iV
            diag, off_re, off_im = self.I[..., 1], self.I[..., 2], self.I[..., 3]

        B[..., 0, 0] = self.I[..., 0] + diag
        B[..., 0, 1] = off_re + 1j * off_im
        B[..., 1, 0] = off_re - 1j * off_im
        B[..., 1, 1] = self.I[..., 0] - diag

        # calculations only for px >= sensitivity cut
        mask = (self.sensitivity_cut <= self.SI)[..., 0]
        B = B[mask]

        return B, mask, self.lin_dop, self.circ_dop

    @staticmethod
    def _as_pair(value, name: str) -> list:
        """Return ``value`` as a new list of length 2.

        Scalars and sequences of length 1 are duplicated. The input
        is never modified in place.
        """
        pair = [value] if isinstance(value, (int, float)) else list(value)

        if len(pair) < 2:
            pair = pair * 2
        elif len(pair) > 2:
            raise ValueError(f"Expected len of '{name}' to be 2!")

        return pair

    def rand_polarization_field(
        self,
        shape: list[int] | int,
        order: list[int] | int = 1,
        random_state: int | None = None,
        scale: list | None = None,
        threshold: float | None = None,
    ) -> torch.Tensor:
        """
        Generates a random noise mask for polarization.

        Parameters
        ----------
        shape : array_like (M, N), or int
            The size of the sky image.
        order : array_like (M, N) or int, optional
            Morphology of the random noise. Higher values create
            more and smaller fluctuations. Default: ``1``.
        random_state : int, optional
            Random state for the random number generator. If ``None``,
            the generator is not re-seeded. Default: ``None``.
        scale : array_like, optional
            Scaling of the distribution of the image. If ``None``,
            the minimum and maximum of the smoothed random image
            are used. Default: ``None``
        threshold : float, optional
            If not None, an upper threshold is applied to the image.
            Default: ``None``

        Returns
        -------
        im : torch.tensor
            An array of the requested shape containing random noise
            values between scale[0] and scale[1]. If ``threshold``
            is given, a boolean mask of the same shape that marks
            the pixels below the threshold.

        Raises
        ------
        ValueError
            If ``shape``, ``order`` or ``scale`` have an invalid length.
        """
        # explicit None check so that a seed of 0 is honoured, too
        if random_state is not None:
            torch.manual_seed(random_state)

        shape = self._as_pair(shape, "shape")
        order = self._as_pair(order, "order")

        sigma = torch.mean(torch.tensor(shape).double()) / (40 * torch.tensor(order))

        im = torch.rand(shape)
        im = scipy.ndimage.gaussian_filter(im.numpy(), sigma=sigma.numpy())

        if scale is None:
            scale = [float(im.min()), float(im.max())]

        if len(scale) != 2:
            raise ValueError("Expected len of 'scale' to be 2!")

        # Rank transform: replace every pixel by the value of a linear ramp
        # at the pixel's rank, which makes the distribution uniform.
        im_flatten = torch.from_numpy(im.flatten())
        im_argsort = torch.argsort(torch.argsort(im_flatten))
        im_linspace = torch.linspace(scale[0], scale[1], im_argsort.size()[0])

        im = torch.reshape(im_linspace[im_argsort], im.shape)

        if threshold is not None:
            # Keep the image shape. Selecting the pixels via boolean
            # indexing would flatten the field and break the per-pixel
            # multiplication in ``dop``.
            im = im < threshold

        return im
def vis_loop(
    obs,
    SI: torch.Tensor,
    num_threads: int = 10,
    noise_level: float = 0,
    noise_mode: str = "sefd",
    telescope: str = "meerkat",
    band: str | None = None,
    mode: str = "full",
    batch_size: int | Literal["auto"] = "auto",
    show_progress: bool = False,
    normalize: bool = True,
    ft: Literal["default", "finufft", "reversed"] = "default",
) -> Visibilities:
    r"""Computes the visibilities of an observation.

    The polarization settings (``polarization``, ``pol_kwargs``,
    ``field_kwargs``), the sensitivity cut and the device are taken
    from ``obs``.

    Parameters
    ----------
    obs : Observation class object
        Observation class object generated by the
        `~pyvisgen.simulation.Observation` class.
    SI : torch.tensor
        Tensor containing the sky intensity distribution.
    num_threads : int, optional
        Number of threads used for intraoperative parallelism
        on the CPU. See `~torch.set_num_threads`. Default: 10
    noise_level : float, optional
        Noise amplitude: SEFD in Jy when ``noise_mode='sefd'``,
        or T_sys/η in K when ``noise_mode='tsys'``. Set to 0 to disable noise.
        Default: 0
    noise_mode : str, optional
        ``'sefd'``: uniform SEFD noise (backward compatible, no elevation dependence).
        ``'tsys'``: elevation-dependent noise from system temperature.
        Default: ``'sefd'``
    telescope : str, optional
        Telescope name for elevation-dependent Tsys corrections.
        Only used when ``noise_mode='tsys'``. Default: ``'meerkat'``
    band : str or None, optional
        Forwarded unchanged to
        :func:`~pyvisgen.simulation.noise.generate_noise`.
        Default: ``None``
    mode : str, optional
        Select one of `'full'`, `'grid'`, or `'dense'` to get
        all valid baselines, a grid of unique baselines, or
        dense baselines. Default: 'full'
    batch_size : int or 'auto', optional
        Batch size for iteration over baselines. If ``'auto'``,
        the number of selected baselines is used as batch size.
        The value is passed on to ``adaptive_batch_size``.
        Default: ``'auto'``
    show_progress : bool, optional
        If `True`, show a progress bar during the iteration over the
        batches of baselines. Default: False
    normalize : bool, optional
        If ``True``, normalize stokes matrix ``B`` by a factor 0.5.
        Default: ``True``
    ft : str, optional
        Sets the type of fourier transform used in the RIME.
        Choose one of ``'default'``, ``'finufft'`` (Flatiron Institute
        Nonuniform Fast Fourier Transform) or `'reversed'`.
        Default: ``'default'``

    Returns
    -------
    visibilities : Visibilities
        Dataclass object containing visibilities and baselines.

    Raises
    ------
    ValueError
        If ``batch_size`` is neither an int nor ``'auto'``, if ``mode``
        is not supported, or if ``mode='dense'`` is requested on the CPU.
    """
    # Validate the arguments first, so that an invalid call fails before
    # global torch settings are changed and the stokes matrix is computed.
    if not (
        isinstance(batch_size, int)
        or (isinstance(batch_size, str) and batch_size == "auto")
    ):
        raise ValueError("Expected batch_size to be 'auto' or type int")

    if mode not in ("full", "grid", "dense"):
        raise ValueError(f"Unsupported mode: {mode}")

    if mode == "dense" and obs.device == torch.device("cpu"):
        raise ValueError("Mode 'dense' is only available for GPU calculations!")

    # NOTE: both settings are process-wide
    torch.set_num_threads(num_threads)
    torch._dynamo.config.suppress_errors = True

    pol = Polarization(
        SI,
        sensitivity_cut=obs.sensitivity_cut,
        polarization=obs.polarization,
        device=obs.device,
        field_kwargs=obs.field_kwargs,
        **obs.pol_kwargs,
    )
    B, mask, lin_dop, circ_dop = pol.stokes_matrix()

    # only keep the pixels above the sensitivity cut
    lm = obs.lm[mask]
    rd = obs.rd[mask]

    # normalize visibilities to factor 0.5,
    # so that the Stokes I image is normalized to 1
    if normalize:
        B *= 0.5

    # empty container, the batch loop appends to it
    n_spw = len(obs.waves_low)
    visibilities = Visibilities(
        V_11=torch.empty(size=[0, n_spw]),
        V_22=torch.empty(size=[0, n_spw]),
        V_12=torch.empty(size=[0, n_spw]),
        V_21=torch.empty(size=[0, n_spw]),
        weights=torch.tensor([]),
        num=torch.tensor([]),
        base_num=torch.tensor([]),
        u=torch.tensor([]),
        v=torch.tensor([]),
        w=torch.tensor([]),
        date=torch.tensor([]),
        st_id_pairs=torch.empty(0, 2),
        linear_dop=torch.tensor([]),
        circular_dop=torch.tensor([]),
    )
    vis_num = torch.zeros(1)

    if mode == "dense":
        # We cannot test this with our CI at the moment
        obs.calc_dense_baselines()  # pragma: no cover
        bas = obs.dense_baselines_gpu  # pragma: no cover
    else:
        bas = obs.baselines.get_valid_subset(obs.num_baselines, obs.device)
        if mode == "grid":
            bas = bas.get_unique_grid(
                obs.fov, obs.ref_frequency, obs.img_size, obs.device
            )

    if batch_size == "auto":
        batch_size = bas.baseline_nums.shape[0]

    visibilities = adaptive_batch_size(
        _batch_loop,
        batch_size,
        visibilities=visibilities,
        vis_num=vis_num,
        obs=obs,
        B=B,
        bas=bas,
        lm=lm,
        rd=rd,
        noise_level=noise_level,
        noise_mode=noise_mode,
        telescope=telescope,
        band=band,
        show_progress=show_progress,
        mode=mode,
        ft=ft,
    )

    # the degrees of polarization are maps of the sky and therefore
    # set once instead of being appended per batch
    visibilities.linear_dop = lin_dop.cpu()
    visibilities.circular_dop = circ_dop.cpu()

    return visibilities
def _batch_loop(
    batch_size: int,
    visibilities,
    vis_num: torch.Tensor,
    obs,
    B: torch.Tensor,
    bas,
    lm: torch.Tensor,
    rd: torch.Tensor,
    noise_level: float,
    noise_mode: str,
    telescope: str,
    band: str | None,
    show_progress: bool,
    mode: str,
    ft: Literal["default", "finufft", "reversed"] = "default",
):
    """Main simulation loop of pyvisgen. Computes visibilities
    batchwise.

    Parameters
    ----------
    batch_size : int
        Batch size for loop over Baselines dataclass object.
    visibilities : Visibilities
        Visibilities dataclass object. The results of every batch
        are appended to this object in place.
    vis_num : torch.Tensor
        Number of visibilities. The maximum is used as offset
        for the running visibility numbers.
    obs : Observation
        Observation class object.
    B : torch.tensor
        Stokes matrix containing stokes visibilities.
    bas : Baselines
        Baselines dataclass object.
    lm : torch.tensor
        lm grid.
    rd : torch.tensor
        rd grid.
    noise_level : float
        Noise amplitude: SEFD in Jy when ``noise_mode='sefd'``,
        or T_sys/η in K when ``noise_mode='tsys'``. No noise is
        simulated if set to 0.
    noise_mode : str
        ``'sefd'``: uniform SEFD noise.
        ``'tsys'``: elevation-dependent noise from system temperature.
    telescope : str
        Telescope name for elevation-dependent Tsys corrections.
    band : str or None
        Forwarded unchanged to ``generate_noise``.
    show_progress : bool
        If True, show a progress bar tracking the loop.
    mode : str
        Select one of `'full'`, `'grid'`, or `'dense'` to get
        all valid baselines, a grid of unique baselines, or
        dense baselines.
    ft : str, optional
        Sets the type of fourier transform used in the RIME.
        Choose one of ``'default'``, ``'finufft'`` (Flatiron Institute
        Nonuniform Fast Fourier Transform) or `'reversed'`.
        Default: ``'default'``

    Returns
    -------
    visibilities : Visibilities
        Visibilities dataclass object.
    """
    batches = torch.arange(bas.baseline_nums.shape[0]).split(batch_size)
    batches = tqdm(
        batches,
        position=0,
        disable=not show_progress,
        desc="Computing visibilities",
        postfix=f"Batch size: {batch_size}",
    )

    rime = RIMEScan(ft=ft, mode=mode, obs=obs, lm=lm, rd=rd)

    # the spectral windows are the same for every batch
    spectral_windows = list(zip(obs.waves_low, obs.waves_high))

    for p in batches:
        bas_p = bas[p]

        # one RIME evaluation per spectral window, stacked along axis 0
        int_values = torch.stack(
            [
                rime(B, bas_p, spw_low=wave_low, spw_high=wave_high)
                for wave_low, wave_high in spectral_windows
            ]
        )
        if int_values.numel() == 0:
            continue

        # baselines first, spectral windows second
        int_values = torch.swapaxes(int_values, 0, 1)

        # In case any row contains NaN
        valid = ~torch.isnan(int_values).any(dim=(1, 2, 3))
        int_values = int_values[valid]

        if int_values.shape[0] == 0:
            # Nothing left to store. Skipping the batch also keeps
            # ``vis_num`` non-empty; ``vis_num.max()`` raises for an
            # empty tensor.
            continue

        if noise_level != 0:
            # The elevations are filtered with the same mask as all other
            # per-baseline quantities, so that they match the rows of
            # ``int_values``.
            # NOTE(review): assumes el1_valid / el2_valid are tensors with
            # one entry per baseline of the batch -- confirm.
            noise, weights = generate_noise(
                int_values.shape,
                obs,
                noise_level,
                mode=noise_mode,
                el1_deg=bas_p.el1_valid[valid],
                el2_deg=bas_p.el2_valid[valid],
                telescope=telescope,
                band=band,
            )
            int_values += noise
        else:
            weights = torch.ones(int_values.shape[0], int_values.shape[1])

        # continue the numbering where the previous batch stopped
        vis_num = torch.arange(int_values.shape[0]) + 1 + vis_num.max()

        vis = Visibilities(
            V_11=int_values[..., 0, 0].cpu(),
            V_22=int_values[..., 1, 1].cpu(),
            V_12=int_values[..., 0, 1].cpu(),
            V_21=int_values[..., 1, 0].cpu(),
            weights=weights.cpu(),
            num=vis_num,
            base_num=bas_p.baseline_nums[valid].cpu(),
            u=bas_p.u_valid[valid].cpu(),
            v=bas_p.v_valid[valid].cpu(),
            w=bas_p.w_valid[valid].cpu(),
            date=bas_p.date[valid].cpu(),
            st_id_pairs=bas_p.st_id_pairs[valid].cpu(),
            linear_dop=torch.tensor([]),
            circular_dop=torch.tensor([]),
        )

        visibilities.add(vis)

        # drop the reference before the next batch is computed
        del int_values

    return visibilities