Source code for pyvisgen.simulation.observation

from dataclasses import dataclass, fields
from datetime import datetime

import astropy.units as un
import numpy as np
import torch
from astropy.constants import c
from astropy.coordinates import AltAz, Angle, EarthLocation, Longitude, SkyCoord
from astropy.time import Time
from tqdm.auto import tqdm

from pyvisgen.layouts import layouts
from pyvisgen.simulation.array import Array
from pyvisgen.utils.logging import setup_logger

# Make float64 the default floating-point dtype for tensors created after this
# point. NOTE: this is a process-wide torch setting changed at import time.
torch.set_default_dtype(torch.float64)
# Module-level logger, namespaced after this module.
LOGGER = setup_logger(namespace=__name__)

# Names exported by ``from ... import *``.
__all__ = ["Baselines", "ValidBaselineSubset", "Observation"]


# Default keyword arguments used by ``Observation`` for the polarization
# simulation when the caller does not pass ``pol_kwargs``.
DEFAULT_POL_KWARGS = dict(delta=0, amp_ratio=0.5, random_state=42)

# Default keyword arguments used by ``Observation`` for the random
# polarization field when the caller does not pass ``field_kwargs``.
DEFAULT_FIELD_KWARGS = dict(
    order=[1, 1],
    scale=[0, 1],
    threshold=None,
    random_state=42,
)


@dataclass
class Baselines:
    """Index-aligned per-baseline data of an observation.

    Comprises the station combinations, the u, v, and w coverage, the
    validity of the measured data points (i.e. whether the source is visible
    for the antenna pairs, or not), the observation time, and the parallactic
    angles for each baseline pair.

    Attributes
    ----------
    st1 : :class:`torch.Tensor`
        Station IDs of the first antenna of each pair.
    st2 : :class:`torch.Tensor`
        Station IDs of the second antenna of each pair.
    u : :class:`torch.Tensor`
        u coordinate coverage.
    v : :class:`torch.Tensor`
        v coordinate coverage.
    w : :class:`torch.Tensor`
        w coordinate coverage.
    valid : :class:`torch.Tensor`
        Mask of valid values, i.e. where the source is visible
        to the antenna pairs.
    time : :class:`torch.Tensor`
        Observation time steps. :meth:`get_valid_subset` interprets the
        values as MJD expressed in seconds.
    q1 : :class:`torch.Tensor`
        Parallactic angle values of the first antenna of each pair.
    q2 : :class:`torch.Tensor`
        Parallactic angle values of the second antenna of each pair.
    """

    st1: torch.Tensor
    st2: torch.Tensor
    u: torch.Tensor
    v: torch.Tensor
    w: torch.Tensor
    valid: torch.Tensor
    time: torch.Tensor
    q1: torch.Tensor
    q2: torch.Tensor

    def __getitem__(self, i):
        """Returns a new ``Baselines`` holding element ``i`` of every field."""
        return Baselines(*[getattr(self, f.name)[i] for f in fields(self)])

    def add_baseline(self, baselines) -> None:
        """Appends the fields of another ``Baselines`` object in place.

        Parameters
        ----------
        baselines : :class:`~pyvisgen.simulation.Baselines`
            Dataclass object whose fields are concatenated to the
            corresponding fields of this object.
        """
        for f in fields(self):
            merged = torch.cat([getattr(self, f.name), getattr(baselines, f.name)])
            setattr(self, f.name, merged)

    def get_valid_subset(self, num_baselines: int, device: str):
        """Returns the valid subset of the baselines using the information
        stored in the ``valid`` field.

        Every field is reshaped to ``(-1, num_baselines)`` and consecutive
        rows are paired into integration windows. A window is kept only if
        the baseline is flagged valid in both of its rows.

        Parameters
        ----------
        num_baselines : int
            Number of baselines used in the observation.
        device : str
            Name of the device to run the operation on,
            e.g. ``'cuda'`` or ``'cpu'``.

        Returns
        -------
        ValidBaselineSubset
            :class:`~pyvisgen.simulation.ValidBaselineSubset` dataclass
            object containing start, stop and mean values of u, v, w and
            the parallactic angles, the baseline numbers, and the mean
            Julian date of every kept window.
        """
        bas_reshaped = Baselines(
            *[getattr(self, f.name).reshape(-1, num_baselines) for f in fields(self)]
        )

        # Valid at the window start AND at the window end.
        mask = bas_reshaped.valid[:-1].bool() & bas_reshaped.valid[1:].bool()

        # Encode each antenna pair as 256 * (st1 + 1) + (st2 + 1).
        baseline_nums = (
            256 * (bas_reshaped.st1[:-1][mask].ravel() + 1)
            + bas_reshaped.st2[:-1][mask].ravel()
            + 1
        ).to(device)

        def _window(values):
            """Start, stop and mean of ``values`` over the kept windows."""
            start = values[:-1][mask].to(device)
            stop = values[1:][mask].to(device)
            return start, stop, (start + stop) / 2

        u_start, u_stop, u_valid = _window(bas_reshaped.u)
        v_start, v_stop, v_valid = _window(bas_reshaped.v)
        w_start, w_stop, w_valid = _window(bas_reshaped.w)
        q1_start, q1_stop, q1_valid = _window(bas_reshaped.q1)
        q2_start, q2_stop, q2_valid = _window(bas_reshaped.q2)

        # Seconds -> days (MJD) -> Julian date; then take the window mean.
        t = Time(bas_reshaped.time / (60 * 60 * 24), format="mjd").jd
        date = (torch.from_numpy(t[:-1][mask] + t[1:][mask]) / 2).to(device)

        return ValidBaselineSubset(
            u_start,
            u_stop,
            u_valid,
            v_start,
            v_stop,
            v_valid,
            w_start,
            w_stop,
            w_valid,
            baseline_nums,
            date,
            q1_start,
            q1_stop,
            q1_valid,
            q2_start,
            q2_stop,
            q2_valid,
        )
@dataclass
class ValidBaselineSubset:
    """Valid baselines subset dataclass.

    Attributes ending on ``start`` and ``stop`` are the values at the
    beginning and the end of an integration window. Attributes ending on
    ``valid`` are the mean of the two, as computed by
    :meth:`Baselines.get_valid_subset`.

    Attributes
    ----------
    u_start : :class:`torch.Tensor`
        Start value for u coverage integration.
    u_stop : :class:`torch.Tensor`
        Stop value for u coverage integration.
    u_valid : :class:`torch.Tensor`
        Valid u values.
    v_start : :class:`torch.Tensor`
        Start value for v coverage integration.
    v_stop : :class:`torch.Tensor`
        Stop value for v coverage integration.
    v_valid : :class:`torch.Tensor`
        Valid v values.
    w_start : :class:`torch.Tensor`
        Start value for w coverage integration.
    w_stop : :class:`torch.Tensor`
        Stop value for w coverage integration.
    w_valid : :class:`torch.Tensor`
        Valid w values.
    baseline_nums : :class:`torch.Tensor`
        Baseline number of every entry, encoded by
        :meth:`Baselines.get_valid_subset` as
        ``256 * (st1 + 1) + st2 + 1``.
    date : :class:`torch.Tensor`
        Julian date of every entry (mean of window start and stop).
    q1_start : :class:`torch.Tensor`
        Parallactic angle of the first antenna at the window start.
    q1_stop : :class:`torch.Tensor`
        Parallactic angle of the first antenna at the window end.
    q1_valid : :class:`torch.Tensor`
        Valid parallactic angle values (first half of the pair).
    q2_start : :class:`torch.Tensor`
        Parallactic angle of the second antenna at the window start.
    q2_stop : :class:`torch.Tensor`
        Parallactic angle of the second antenna at the window end.
    q2_valid : :class:`torch.Tensor`
        Valid parallactic angle values (second half of the pair).
    """

    u_start: torch.Tensor
    u_stop: torch.Tensor
    u_valid: torch.Tensor
    v_start: torch.Tensor
    v_stop: torch.Tensor
    v_valid: torch.Tensor
    w_start: torch.Tensor
    w_stop: torch.Tensor
    w_valid: torch.Tensor
    baseline_nums: torch.Tensor
    date: torch.Tensor
    q1_start: torch.Tensor
    q1_stop: torch.Tensor
    q1_valid: torch.Tensor
    q2_start: torch.Tensor
    q2_stop: torch.Tensor
    q2_valid: torch.Tensor

    def __getitem__(self, i):
        """Returns element(s) ``i`` of all fields, stacked into one tensor.

        The rows of the result follow the field declaration order, so
        ``self[:]`` is the full ``(17, N)`` tensor of all fields. Earlier
        versions ignored ``i`` and always returned the full stack; the index
        is now applied to every field before stacking.
        """
        return torch.stack([getattr(self, f.name)[i] for f in fields(self)])

    @staticmethod
    def _to_julian_date(t):
        """Converts a time bound to a Julian date comparable with ``date``."""
        if isinstance(t, datetime):
            return Time(t, scale="utc").jd
        # astropy ``Time`` objects expose ``jd``; plain numbers pass through.
        return getattr(t, "jd", t)

    def get_timerange(self, t_start, t_stop):
        """Returns all attributes that fall into the time range
        [``t_start``, ``t_stop``] (both bounds inclusive).

        Parameters
        ----------
        t_start : datetime or astropy.time.Time or float
            Start date. Numbers are interpreted as Julian dates.
        t_stop : datetime or astropy.time.Time or float
            End date. Numbers are interpreted as Julian dates.

        Returns
        -------
        ValidBaselineSubset
            :class:`~pyvisgen.simulation.ValidBaselineSubset` dataclass
            object containing the flattened entries of every field whose
            ``date`` lies between ``t_start`` and ``t_stop``.
        """
        jd_start = self._to_julian_date(t_start)
        jd_stop = self._to_julian_date(t_stop)

        date = self.date.ravel()
        in_range = (date >= jd_start) & (date <= jd_stop)

        # Build the subset directly: ``__getitem__`` returns a stacked tensor,
        # not a dataclass, so it cannot be used for the documented return.
        return ValidBaselineSubset(
            *[getattr(self, f.name).ravel()[in_range] for f in fields(self)]
        )

    def get_unique_grid(
        self,
        fov: float,
        ref_frequency: float,
        img_size: int,
        device: str,
    ):
        """Returns the unique grid for a given FOV, frequency, and image size.

        The valid (u, v) points are sorted into ``img_size`` bins per axis
        and only the first entry found for every occupied (u, v) cell is
        kept.

        Parameters
        ----------
        fov : float
            Size of the FOV. The value is divided by 3600 and converted
            from degrees to radians, i.e. it is treated as arcseconds.
        ref_frequency : float
            Reference frequency.
        img_size : int
            Size of the image.
        device : str
            Name of the device to run the operation on,
            e.g. ``'cuda'`` or ``'cpu'``.

        Returns
        -------
        torch.Tensor
            Stacked tensor (rows in field order, see :meth:`__getitem__`)
            restricted to one entry per occupied grid cell.
        """
        uv = torch.cat([self.u_valid[None], self.v_valid[None]], dim=0)

        # Extended precision for the bin edges. ``np.longdouble`` is the
        # portable spelling of ``np.float128``, which not every platform has.
        fov = np.deg2rad(fov / 3600, dtype=np.longdouble)
        delta = fov ** (-1) * c.value / ref_frequency

        bins = torch.from_numpy(
            np.arange(
                start=-(img_size / 2 + 1 / 2) * delta,
                stop=(img_size / 2 + 1 / 2) * delta,
                step=delta,
                dtype=np.longdouble,
            ).astype(np.float64)
        ).to(device)

        # Rounding in ``arange`` can yield one edge too many; drop it.
        if len(bins) - 1 > img_size:
            bins = bins[:-1]

        indices_bucket = torch.bucketize(uv, bins)
        indices_bucket_sort, _ = self._lexsort(indices_bucket)

        # After lexicographic sorting equal cells are adjacent.
        _, indices_unique_inv, counts = torch.unique_consecutive(
            indices_bucket[:, indices_bucket_sort],
            dim=1,
            return_inverse=True,
            return_counts=True,
        )

        # The offset of every group start is the cumulative count of the
        # groups before it; the stable sort keeps the first member first.
        _, ind_sorted = torch.sort(indices_unique_inv, stable=True)
        cum_sum = counts.cumsum(0)
        cum_sum = torch.cat((torch.tensor([0], device=device), cum_sum[:-1]))
        first_indices = ind_sorted[cum_sum]

        return self[:][:, indices_bucket_sort[first_indices]]

    def _lexsort(
        self, a: torch.Tensor, dim: int = -1
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Sort a sequence of tensors in lexicographic order.

        Parameters
        ----------
        a : torch.Tensor
            2d tensor whose rows are the sort keys.
        dim : int, optional
            The dimension along which to sort. Only ``-1`` is supported.
            Default: ``-1``

        Returns
        -------
        tuple of torch.Tensor
            Indices that sort the columns of ``a``, and for every column
            the index of its value among the sorted unique columns.
        """
        assert dim == -1  # Transpose if you want differently
        assert a.ndim == 2  # Not sure what is numpy behaviour with > 2 dim
        # To be consistent with numpy, we flip the keys (sort by last row first)
        _, inv = torch.unique(a.flip(0), dim=dim, sorted=True, return_inverse=True)
        return torch.argsort(inv), inv
@dataclass
class Scan:
    """Timing information of a single scan.

    A scan is a period in which the simulated interferometer measures
    visibilities. It has a start and a stop time, and measurements are taken
    in regular steps of ``integration_time`` beginning at ``start``.
    ``separation`` is the pause after (!) this scan, before the next one
    begins.

    NOTE(review): the number of steps is the floor of
    ``(stop - start) / integration_time``. If the scan duration is not a
    multiple of the integration time, the remainder at the end of the scan
    is not sampled, and the clamp to ``stop`` in :meth:`get_timesteps` never
    takes effect. Confirm whether a shortened final step was intended.

    Attributes
    ----------
    start : astropy.time.Time
        The start time of the scan.
    stop : astropy.time.Time
        The stop time of the scan.
    separation : float
        The separation to the next scan in seconds.
    integration_time : float
        The integration time of the interferometer for this scan.
    """

    start: Time
    stop: Time
    separation: float
    integration_time: float

    def get_num_timesteps(self):
        """Get the number of whole integration steps in the scan.

        Returns
        -------
        num_timesteps : int
        """
        duration = (self.stop - self.start).to(un.second)
        whole_steps = duration // self.integration_time
        return int(whole_steps.value)

    def get_timesteps(self):
        """Get the boundaries of the time steps.

        Returns
        -------
        timesteps : astropy.time.Time
            ``get_num_timesteps() + 1`` time stamps spaced by
            ``integration_time``, starting at ``start`` and never later
            than ``stop``.
        """
        stamps = []
        for step in range(self.get_num_timesteps() + 1):
            candidate = self.start + step * self.integration_time
            stamps.append(min([candidate, self.stop]))
        return Time(stamps)
class Observation:
    """Main observation simulation class.

    The :class:`~pyvisgen.simulation.Observation` class simulates the
    baselines and time steps during the observation.

    Parameters
    ----------
    src_ra : float
        Source right ascension coordinate in degrees.
    src_dec : float
        Source declination coordinate in degrees.
    start_time : datetime
        Observation start time.
    scan_duration : int | ArrayLike
        Scan duration in seconds or array of scan durations
        with size=num_scans.
    num_scans : int
        Number of scans.
    scan_separation : int | ArrayLike
        Scan separation in seconds or array of scan separations
        with size=num_scans - 1.
    integration_time : int | ArrayLike
        Integration time in seconds or array of integration times
        with size=num_scans.
    ref_frequency : float
        Reference frequency in Hertz.
    frequency_offsets : list
        Frequency offsets in Hertz.
    bandwidths : list
        Frequency bandwidth in Hertz.
    fov : float
        Field of view in arcseconds.
    image_size : int
        Image size of the sky distribution.
    array_layout : str
        Name of an existing array layout. See :mod:`~pyvisgen.layouts`.
    corrupted : bool
        If ``True``, apply corruption during the vis loop.
    device : str
        Torch device to select for computation.
    dense : bool, optional
        If ``True``, apply dense baseline calculation of a perfect
        interferometer. Default: ``False``
    sensitivity_cut : float, optional
        Sensitivity threshold, where only pixels above the value
        are kept. Default: ``1e-6``
    polarization : str, optional
        Choose between ``'linear'`` or ``'circular'`` or ``None`` to
        simulate different types of polarizations or disable
        the simulation of polarization. Default: ``None``
    pol_kwargs : dict, optional
        Additional keyword arguments for the simulation of polarization.
        Default: ``{'delta': 0,'amp_ratio': 0.5,'random_state': 42}``
    field_kwargs : dict, optional
        Additional keyword arguments for the random polarization
        field that is applied when simulating polarization.
        Default:
        ``{'order': [1, 1],'scale': [0, 1],'threshold': None,'random_state': 42}``
    show_progress : bool, optional
        If ``True``, show a progress bar during the iteration over the
        scans. Default: ``False``

    Notes
    -----
    See :class:`~pyvisgen.simulation.polarization` and
    :class:`~pyvisgen.simulation.polarization.rand_polarization_field`
    for more information on the keyword arguments in ``pol_kwargs``
    and ``field_kwargs``, respectively.
    """

    def __init__(
        self,
        src_ra: float,
        src_dec: float,
        start_time: datetime,
        scan_duration: int | np.typing.ArrayLike,
        num_scans: int,
        scan_separation: int | np.typing.ArrayLike,
        integration_time: int | np.typing.ArrayLike,
        ref_frequency: float,
        frequency_offsets: list,
        bandwidths: list,
        fov: float,
        image_size: int,
        array_layout: str,
        corrupted: bool,
        device: str,
        dense: bool = False,
        sensitivity_cut: float = 1e-6,
        polarization: str = None,
        pol_kwargs: dict = DEFAULT_POL_KWARGS,
        field_kwargs: dict = DEFAULT_FIELD_KWARGS,
        show_progress: bool = False,
    ) -> None:
        """Sets up the observation class.

        See the class docstring for a description of all parameters.
        """
        from copy import deepcopy

        self.ra = torch.tensor(src_ra).double()
        self.dec = torch.tensor(src_dec).double()
        self.show_progress = show_progress

        self.start = Time(start_time.isoformat(), format="isot", scale="utc")
        self.scan_duration = scan_duration  # Duration of scans (in seconds)
        self.num_scans = num_scans  # Number of scans
        self.scan_separation = scan_separation  # Separation between scans (in s)
        # Integration time (either one time for all scans or one for each scan)
        self.int_time = integration_time
        self.scans = self.create_scans()

        self.ref_frequency = torch.tensor(ref_frequency)
        self.bandwidths = torch.tensor(bandwidths)
        self.frequency_offsets = torch.tensor(frequency_offsets)

        # Lower and upper edge of every frequency band.
        band_centers = self.ref_frequency + self.frequency_offsets
        self.waves_low = band_centers - self.bandwidths / 2
        self.waves_high = band_centers + self.bandwidths / 2

        self.fov = fov
        self.img_size = image_size
        self.pix_size = fov / image_size

        self.corrupted = corrupted
        self.sensitivity_cut = sensitivity_cut
        self.device = torch.device(device)

        self.layout = array_layout
        self.array = layouts.get_array_layout(array_layout)
        self.array_earth_loc = EarthLocation.from_geocentric(
            self.array.x, self.array.y, self.array.z, unit=un.m
        )
        # Number of distinct antenna pairs: n * (n - 1) / 2.
        self.num_baselines = int(
            len(self.array.st_num) * (len(self.array.st_num) - 1) / 2
        )

        if dense:  # pragma: no cover
            self.waves_low = [self.ref_frequency]
            self.waves_high = [self.ref_frequency]
            self.calc_dense_baselines()
            self.ra = torch.tensor([0]).to(self.device)
            self.dec = torch.tensor([0]).to(self.device)
        else:
            self.calc_baselines()
            self.baselines.num = int(
                self.array.st_num.size(dim=0)
                * (self.array.st_num.size(dim=0) - 1)
                / 2
            )
            self.baselines.times_unique = torch.unique(self.baselines.time)

        self.rd = self.create_rd_grid()
        self.lm = self.create_lm_grid()

        # polarization
        self.polarization = polarization
        # Never hand out the shared module-level default dicts: mutating
        # them through one instance would silently change the defaults of
        # every later instance. Explicitly passed dicts are stored as given.
        if pol_kwargs is DEFAULT_POL_KWARGS:
            pol_kwargs = deepcopy(DEFAULT_POL_KWARGS)
        if field_kwargs is DEFAULT_FIELD_KWARGS:
            field_kwargs = deepcopy(DEFAULT_FIELD_KWARGS)
        self.pol_kwargs = pol_kwargs
        self.field_kwargs = field_kwargs

    @staticmethod
    def _empty_baselines():
        """Returns a ``Baselines`` object whose fields are all empty tensors."""
        return Baselines(*[torch.tensor([]) for _ in fields(Baselines)])

    def create_scans(self):
        """Calculates individual scans of the observation based on the
        number of scans, their duration and the integration time.

        Scalar settings are broadcast to all scans. Every scan starts where
        the previous one stopped plus the previous scan's separation; the
        last scan gets a separation of zero.

        Returns
        -------
        list[Scan]
            List of scans with a specific start, stop and integration time.
        """

        def _as_array(arr, size: int):
            if np.isscalar(arr):
                return np.ones(size) * arr
            return np.asarray(arr)

        scan_duration = _as_array(self.scan_duration, size=self.num_scans)
        scan_separation = np.append(
            _as_array(self.scan_separation, size=self.num_scans - 1), 0
        )
        int_time = _as_array(self.int_time, size=self.num_scans)

        scans = []
        current_time = self.start
        for i in range(self.num_scans):
            scan = Scan(
                start=current_time,
                stop=current_time + scan_duration[i] * un.second,
                integration_time=int_time[i] * un.second,
                separation=scan_separation[i] * un.second,
            )
            scans.append(scan)
            current_time = scan.stop + scan.separation

        return scans

    def calc_dense_baselines(self):  # pragma: no cover
        """Calculates the baselines of a densely-built antenna array,
        which would provide full coverage of the uv space.

        The result is stored in ``self.dense_baselines_gpu`` as a stacked
        tensor with the rows u (3x), v (3x), zeros (3x) and ones (2x).
        """
        N = self.img_size
        # ``np.longdouble`` is the portable spelling of ``np.float128``.
        fov = np.deg2rad(self.fov / 3600, dtype=np.longdouble)
        delta = fov ** (-1) * c.value / self.ref_frequency

        u_dense = torch.from_numpy(
            np.arange(
                start=-(N / 2) * delta,
                stop=(N / 2) * delta,
                step=delta,
                dtype=np.longdouble,
            ).astype(np.float64)
        ).to(self.device)
        v_dense = u_dense

        uu, vv = torch.meshgrid(u_dense, v_dense, indexing="xy")
        u = uu.flatten()
        v = vv.flatten()

        self.dense_baselines_gpu = torch.stack(
            [
                u,
                u,
                u,
                v,
                v,
                v,
                torch.zeros(u.shape, device=self.device),
                torch.zeros(u.shape, device=self.device),
                torch.zeros(u.shape, device=self.device),
                torch.ones(u.shape, device=self.device),
                torch.ones(u.shape, device=self.device),
            ]
        )

    def calc_baselines(self):
        """Initializes :class:`~pyvisgen.simulation.Baselines` dataclass
        object and calls
        :py:func:`~pyvisgen.simulation.Observation.get_baselines`
        for every scan to compute the contents of the
        :class:`~pyvisgen.simulation.Baselines` dataclass.
        """
        self.baselines = self._empty_baselines()

        for scan in tqdm(
            self.scans,
            disable=not self.show_progress,
            desc="Computing baselines",
        ):
            bas = self.get_baselines(scan.get_timesteps())
            self.baselines.add_baseline(bas)

    def get_baselines(self, times):
        """Calculates baselines from source coordinates and time of
        observation for every antenna station in array_layout.

        Parameters
        ----------
        times : astropy.time.Time
            Time steps of the observation.

        Returns
        -------
        Baselines
            Baselines between telescopes with visibility flags, one entry
            per antenna pair and time step.
        """
        # catch rare case where dimension of times is 0
        if times.ndim == 0:
            times = Time([times])

        # calculate GHA, local HA, and station elevation for all times.
        GHA, ha_local, el_st_all = self.calc_ref_elev(time=times)

        ar = Array(self.array)
        delta_x, delta_y, delta_z = ar.calc_relative_pos
        st_num_pairs, els_low_pairs, els_high_pairs = ar.calc_ant_pair_vals

        baselines = self._empty_baselines()

        # Parallactic angles per antenna, combined into antenna pairs and
        # regrouped as (time step, baseline, antenna of the pair).
        q_all = self.calc_feed_rotation(ha_local)
        q_comb = torch.vstack([torch.combinations(qi) for qi in q_all])
        q_comb = q_comb.reshape(-1, int(q_comb.shape[0] / times.shape[0]), 2)

        # Loop over ha, el_st, times, parallactic angles
        for ha, el_st, time, qc in zip(GHA, el_st_all, times, q_comb):
            u, v, w = self.calc_direction_cosines(ha, el_st, delta_x, delta_y, delta_z)

            # calc current elevations
            cur_el_st = torch.combinations(el_st)

            # A baseline is invalid as soon as one of its two antennas is
            # outside of the allowed elevation range of the pair.
            too_low = (cur_el_st < els_low_pairs).any(axis=1)
            too_high = (cur_el_st > els_high_pairs).any(axis=1)
            valid = torch.ones(u.shape).bool()
            valid[torch.logical_or(too_low, too_high)] = False

            # MJD in seconds, repeated for every baseline of this time step.
            time_mjd = torch.repeat_interleave(
                torch.tensor(time.mjd) * (24 * 60 * 60), len(valid)
            )

            # collect baselines
            base = Baselines(
                st_num_pairs[..., 0],
                st_num_pairs[..., 1],
                u,
                v,
                w,
                valid,
                time_mjd,
                qc[..., 0].ravel(),
                qc[..., 1].ravel(),
            )
            baselines.add_baseline(base)

        return baselines

    def calc_ref_elev(self, time) -> tuple:
        """Calculates the station elevations for given time steps.

        Parameters
        ----------
        time : astropy.time.Time
            Observation time steps.

        Returns
        -------
        tuple
            Tuple containing tensors of the Greenwich hour angle (degrees),
            the antenna-local hour angles (radians), and the station
            elevations (degrees).

        Raises
        ------
        ValueError
            If the hour angles and the elevations do not cover the same
            number of time steps.
        """
        if time.shape == ():
            time = time[None]

        src_crd = SkyCoord(
            ra=self.ra.numpy(), dec=self.dec.numpy(), unit=(un.deg, un.deg)
        )

        # Calculate for all times
        # calculate GHA, Greenwich as reference
        GHA = time.sidereal_time("apparent", "greenwich") - src_crd.ra.to(un.hourangle)

        # calculate local sidereal time and HA at each antenna
        lst = un.Quantity(
            [
                Time(time, location=loc).sidereal_time("mean")
                for loc in self.array_earth_loc
            ]
        )
        ha_local = torch.from_numpy(
            (lst - Longitude(self.ra.item(), unit=un.deg)).radian
        ).T

        # calculate elevations
        el_st_all = src_crd.transform_to(
            AltAz(
                obstime=time[..., None],
                location=EarthLocation.from_geocentric(
                    torch.repeat_interleave(self.array.x[None], len(time), dim=0),
                    torch.repeat_interleave(self.array.y[None], len(time), dim=0),
                    torch.repeat_interleave(self.array.z[None], len(time), dim=0),
                    unit=un.m,
                ),
            )
        )

        if not len(GHA.value) == len(el_st_all):
            raise ValueError(
                "Expected GHA and el_st_all to have the same length "
                f"but got {len(GHA.value)} and {len(el_st_all)}."
            )

        return (
            torch.tensor(GHA.deg),
            ha_local,
            torch.tensor(el_st_all.alt.degree),
        )

    def calc_feed_rotation(self, ha: torch.Tensor) -> torch.Tensor:
        r"""Calculates feed rotation for every antenna at every time step.

        Parameters
        ----------
        ha : torch.Tensor
            Antenna-local hour angles in radians.

        Returns
        -------
        torch.Tensor
            Parallactic angles in radians, same shape as ``ha``.

        Notes
        -----
        The calculation is based on Equation (13.1) of Meeus'
        Astronomical Algorithms:

        .. math::

            q = \arctan\left(\frac{\sin h}{\cos\delta \tan\varphi
            - \sin\delta \cos h}\right),

        where $h$ is the local hour angle, $\varphi$ the geographical
        latitude of the observer, and $\delta$ the declination of
        the source.
        """
        # torch's trigonometric functions expect radians. The declination is
        # stored in degrees and the antenna latitudes are astropy angles, so
        # both are converted explicitly.
        ant_lat = torch.tensor(self.array_earth_loc.lat.to_value(un.rad))
        dec = torch.deg2rad(self.dec)

        # Eqn (13.1) of Meeus' Astronomical Algorithms
        q = torch.arctan2(
            torch.sin(ha),
            torch.tan(ant_lat) * torch.cos(dec) - torch.sin(dec) * torch.cos(ha),
        )

        return q

    def create_rd_grid(self):
        """Calculates RA and Dec values for the field of view around the
        source position.

        Uses ``self.fov``, ``self.img_size`` and ``self.dec``.

        Returns
        -------
        rd_grid : 3d tensor
            Tensor with every pixel containing an RA offset and a Dec
            value, both in radians.
        """
        # transform to rad (``np.longdouble`` is the portable ``np.float128``)
        fov = np.deg2rad(self.fov / 3600, dtype=np.longdouble)

        # define resolution
        res = fov / self.img_size

        dec = torch.deg2rad(self.dec).to(self.device)

        r = torch.from_numpy(
            np.arange(
                start=-(self.img_size / 2) * res,
                stop=(self.img_size / 2) * res,
                step=res,
                dtype=np.longdouble,
            ).astype(np.float64)
        ).to(self.device)
        d = r + dec

        R, _ = torch.meshgrid((r, r), indexing="xy")
        _, D = torch.meshgrid((d, d), indexing="xy")
        rd_grid = torch.cat([R[..., None], D[..., None]], dim=2)

        return rd_grid

    def create_lm_grid(self):
        """Calculates the sine projection of the RA/Dec grid.

        Uses ``self.rd`` and ``self.dec``.

        Returns
        -------
        lm_grid : 3d tensor
            Tensor with every pixel containing an l and m value.
        """
        # Evaluate in extended precision, then go back to float64 for torch.
        dec = np.deg2rad(self.dec.cpu().numpy()).astype(np.longdouble)
        rd = self.rd.cpu().numpy().astype(np.longdouble)

        lm_grid = np.zeros(rd.shape, dtype=np.longdouble)
        lm_grid[..., 0] = np.cos(rd[..., 1]) * np.sin(rd[..., 0])
        lm_grid[..., 1] = np.sin(rd[..., 1]) * np.cos(dec) - np.cos(
            rd[..., 1]
        ) * np.sin(dec) * np.cos(rd[..., 0])

        return torch.from_numpy(lm_grid.astype(np.float64)).to(self.device)

    def calc_direction_cosines(
        self,
        ha: torch.tensor,
        el_st: torch.tensor,
        delta_x: torch.tensor,
        delta_y: torch.tensor,
        delta_z: torch.tensor,
    ):
        """Calculates direction cosines u, v, and w for
        given hour angles and relative antenna positions.

        Parameters
        ----------
        ha : :func:`torch.tensor`
            Tensor containing hour angles in degrees.
        el_st : :func:`torch.tensor`
            Tensor containing station elevations. Not used in the
            calculation.
        delta_x : :func:`torch.tensor`
            Tensor containing relative antenna x-positions.
        delta_y : :func:`torch.tensor`
            Tensor containing relative antenna y-positions.
        delta_z : :func:`torch.tensor`
            Tensor containing relative antenna z-positions.

        Returns
        -------
        u : :func:`torch.tensor`
            Tensor containing direction cosines in u-axis direction.
        v : :func:`torch.tensor`
            Tensor containing direction cosines in v-axis direction.
        w : :func:`torch.tensor`
            Tensor containing direction cosines in w-axis direction.

        Raises
        ------
        ValueError
            If u, v, and w do not have the same shape.
        """
        src_dec = torch.deg2rad(self.dec)
        ha = torch.deg2rad(ha)

        sin_ha, cos_ha = torch.sin(ha), torch.cos(ha)
        sin_dec, cos_dec = torch.sin(src_dec), torch.cos(src_dec)

        u = (sin_ha * delta_x + cos_ha * delta_y).reshape(-1)
        v = (
            -sin_dec * cos_ha * delta_x
            + sin_dec * sin_ha * delta_y
            + cos_dec * delta_z
        ).reshape(-1)
        w = (
            cos_dec * cos_ha * delta_x
            - cos_dec * sin_ha * delta_y
            + sin_dec * delta_z
        ).reshape(-1)

        if not (u.shape == v.shape == w.shape):
            raise ValueError(
                "Expected u, v, and w to have the same shapes "
                f"but got {u.shape}, {v.shape}, and {w.shape}."
            )

        return u, v, w