Source code for pyvisgen.simulation.observation

from dataclasses import dataclass, fields
from datetime import datetime

import astropy.units as un
import numpy as np
import torch
from astropy.constants import c
from astropy.coordinates import AltAz, Angle, EarthLocation, Longitude, SkyCoord
from astropy.time import Time
from tqdm.auto import tqdm

from pyvisgen.layouts import layouts
from pyvisgen.simulation.array import Array
from pyvisgen.utils.logging import setup_logger

# Public API of this module.
__all__ = ["Baselines", "ValidBaselineSubset", "Observation"]

# NOTE: process-wide side effect at import time. Tensors created in this
# module without an explicit dtype (e.g. the empty ``torch.tensor([])``
# placeholders that seed ``Baselines``) pick up float64 from this setting.
torch.set_default_dtype(torch.float64)

# Module logger provided by the project's logging helper.
LOGGER = setup_logger()


# Default keyword arguments for the polarization simulation
# (used for ``Observation``'s ``pol_kwargs`` parameter).
DEFAULT_POL_KWARGS = dict(
    delta=0,
    amp_ratio=0.5,
    random_state=42,
)

# Default keyword arguments for the random polarization field
# (used for ``Observation``'s ``field_kwargs`` parameter).
DEFAULT_FIELD_KWARGS = dict(
    order=[1, 1],
    scale=[0, 1],
    threshold=None,
    random_state=42,
)


@dataclass
class Baselines:
    """The Baselines dataclass comprises of data on station combinations,
    the u, v, and w coverage, validity of the measured data points (i.e.
    whether the source is visible for the antenna pairs, or not),
    observation time and parallactic angles for each baseline pair.

    The fields are kept in lockstep: :meth:`add_baseline` concatenates
    every field along its first dimension, and :meth:`__getitem__` applies
    the same index to every field.

    Attributes
    ----------
    st1 : :func:`~torch.tensor`
        Station IDs of the first antenna of each pair.
    st2 : :func:`~torch.tensor`
        Station IDs of the second antenna of each pair.
    u : :func:`~torch.tensor`
        u coordinate coverage.
    v : :func:`~torch.tensor`
        v coordinate coverage.
    w : :func:`~torch.tensor`
        w coordinate coverage.
    valid : :func:`~torch.tensor`
        Mask of valid values, i.e. where the source is visible to the
        antenna pairs.
    time : :func:`~torch.tensor`
        Tensor of observation time steps.
    q1 : :func:`~torch.tensor`
        Parallactic angles of the first antenna of each pair.
    q2 : :func:`~torch.tensor`
        Parallactic angles of the second antenna of each pair.
    """

    st1: torch.tensor
    st2: torch.tensor
    u: torch.tensor
    v: torch.tensor
    w: torch.tensor
    valid: torch.tensor
    time: torch.tensor
    q1: torch.tensor
    q2: torch.tensor

    def __getitem__(self, i):
        """Returns element at index ``i`` for all fields.

        Returns
        -------
        Baselines
            New object whose fields are ``field[i]`` of this object.
        """
        return Baselines(*[getattr(self, f.name)[i] for f in fields(self)])

    def add_baseline(self, baselines) -> None:
        """Adds a new baseline to the dataclass object.

        Every field of ``baselines`` is appended (via :func:`torch.cat`)
        to the corresponding field of this object, in place.

        Parameters
        ----------
        baselines : :class:`~pyvisgen.simulation.Baselines`
            :class:`~pyvisgen.simulation.Baselines` dataclass object that
            is added to the fields of this dataclass.
        """
        # Plain loop: this is a pure side effect, so no throwaway list.
        for f in fields(self):
            setattr(
                self,
                f.name,
                torch.cat([getattr(self, f.name), getattr(baselines, f.name)]),
            )

    def get_valid_subset(self, num_baselines: int, device: str):
        """Returns a valid subset of the baselines using the information
        stored in the ``valid`` field.

        Every field is reshaped to ``(-1, num_baselines)``, and each pair of
        consecutive rows forms an integration window. A window is kept
        only if the baseline is valid at both its start row and its stop
        row. The ``*_valid`` quantities are the midpoints of the kept
        windows.

        Parameters
        ----------
        num_baselines : int
            Number of baselines used in the observation.
        device : str
            Name of the device to run the operation on,
            e.g. ``'cuda'`` or ``'cpu'``.

        Returns
        -------
        ValidBaselineSubset
            :class:`~pyvisgen.simulation.ValidBaselineSubset` dataclass
            object containing valid u, v, and w coverage, observation time
            steps (window midpoints in Julian days), baseline numbers, and
            parallactic angles.
        """
        bas_reshaped = Baselines(
            *[getattr(self, f.name).reshape(-1, num_baselines) for f in fields(self)]
        )

        # Window i -> i + 1 is usable only if valid at both ends.
        mask = bas_reshaped.valid[:-1].bool() & bas_reshaped.valid[1:].bool()

        def _start(x):
            # Value at the beginning of each kept window.
            return x[:-1][mask].to(device)

        def _stop(x):
            # Value at the end of each kept window.
            return x[1:][mask].to(device)

        # Encode each pair as 256 * (st1 + 1) + (st2 + 1).
        baseline_nums = (
            256 * (bas_reshaped.st1[:-1][mask].ravel() + 1)
            + bas_reshaped.st2[:-1][mask].ravel()
            + 1
        ).to(device)

        u_start, u_stop = _start(bas_reshaped.u), _stop(bas_reshaped.u)
        v_start, v_stop = _start(bas_reshaped.v), _stop(bas_reshaped.v)
        w_start, w_stop = _start(bas_reshaped.w), _stop(bas_reshaped.w)
        u_valid = (u_start + u_stop) / 2
        v_valid = (v_start + v_stop) / 2
        w_valid = (w_start + w_stop) / 2

        q1_start, q1_stop = _start(bas_reshaped.q1), _stop(bas_reshaped.q1)
        q2_start, q2_stop = _start(bas_reshaped.q2), _stop(bas_reshaped.q2)
        q1_valid = (q1_start + q1_stop) / 2
        q2_valid = (q2_start + q2_stop) / 2

        # ``time`` holds MJD in seconds; convert to Julian days, then take
        # the midpoint of each kept window.
        t = Time(bas_reshaped.time / (60 * 60 * 24), format="mjd").jd
        date = (torch.from_numpy(t[:-1][mask] + t[1:][mask]) / 2).to(device)

        return ValidBaselineSubset(
            u_start,
            u_stop,
            u_valid,
            v_start,
            v_stop,
            v_valid,
            w_start,
            w_stop,
            w_valid,
            baseline_nums,
            date,
            q1_start,
            q1_stop,
            q1_valid,
            q2_start,
            q2_stop,
            q2_valid,
        )
@dataclass()
class ValidBaselineSubset:
    """Valid baselines subset dataclass.

    Attributes ending in ``valid`` are midpoint quantities of integration
    windows during which the baseline pair contributed to the measurement
    of the source. Attributes ending in ``start`` are the starting points
    of those integration windows, and attributes ending in ``stop`` are
    their end points.

    Attributes
    ----------
    u_start : :func:`~torch.tensor`
        Start value for u coverage integration.
    u_stop : :func:`~torch.tensor`
        Stop value for u coverage integration.
    u_valid : :func:`~torch.tensor`
        Valid u values.
    v_start : :func:`~torch.tensor`
        Start value for v coverage integration.
    v_stop : :func:`~torch.tensor`
        Stop value for v coverage integration.
    v_valid : :func:`~torch.tensor`
        Valid v values.
    w_start : :func:`~torch.tensor`
        Start value for w coverage integration.
    w_stop : :func:`~torch.tensor`
        Stop value for w coverage integration.
    w_valid : :func:`~torch.tensor`
        Valid w values.
    baseline_nums : :func:`~torch.tensor`
        Baseline identifiers (``256 * (st1 + 1) + st2 + 1`` as built by
        :meth:`Baselines.get_valid_subset`).
    date : :func:`~torch.tensor`
        Time steps of the measurement during which at least one baseline
        pair contributed to the measurement.
    q1_start : :func:`~torch.tensor`
        Start value of the parallactic angle (first antenna of the pair).
    q1_stop : :func:`~torch.tensor`
        Stop value of the parallactic angle (first antenna of the pair).
    q1_valid : :func:`~torch.tensor`
        Valid parallactic angle values (first antenna of the pair).
    q2_start : :func:`~torch.tensor`
        Start value of the parallactic angle (second antenna of the pair).
    q2_stop : :func:`~torch.tensor`
        Stop value of the parallactic angle (second antenna of the pair).
    q2_valid : :func:`~torch.tensor`
        Valid parallactic angle values (second antenna of the pair).
    """

    u_start: torch.tensor
    u_stop: torch.tensor
    u_valid: torch.tensor
    v_start: torch.tensor
    v_stop: torch.tensor
    v_valid: torch.tensor
    w_start: torch.tensor
    w_stop: torch.tensor
    w_valid: torch.tensor
    baseline_nums: torch.tensor
    date: torch.tensor
    q1_start: torch.tensor
    q1_stop: torch.tensor
    q1_valid: torch.tensor
    q2_start: torch.tensor
    q2_stop: torch.tensor
    q2_valid: torch.tensor

    def __getitem__(self, i):
        """Returns element at index ``i`` for all fields, as a stacked tensor.

        All fields are stacked (in declaration order) along a new leading
        dimension. ``i`` then selects along the following (sample)
        dimension. ``self[:]`` therefore returns the complete stacked
        tensor. Previously ``i`` was silently ignored.
        """
        stacked = torch.stack([getattr(self, f.name) for f in fields(self)])
        return stacked[:, i]

    def get_timerange(self, t_start, t_stop):
        """Returns all attributes that fall into the time range
        [``t_start``, ``t_stop``].

        Parameters
        ----------
        t_start : float
            Start date, in the same units as the ``date`` field.
        t_stop : float
            End date, in the same units as the ``date`` field.

        Returns
        -------
        ValidBaselineSubset
            :class:`~pyvisgen.simulation.ValidBaselineSubset` dataclass
            object containing all attributes that fall in the time range
            between ``t_start`` and ``t_stop`` (inclusive).
        """
        flat = {f.name: getattr(self, f.name).ravel() for f in fields(self)}
        in_range = (flat["date"] >= t_start) & (flat["date"] <= t_stop)
        return ValidBaselineSubset(
            **{name: values[in_range] for name, values in flat.items()}
        )

    def get_unique_grid(
        self,
        fov: float,
        ref_frequency: float,
        img_size: int,
        device: str,
    ):
        """Returns the unique grid for a given FOV, frequency, and image size.

        The valid (u, v) points are binned onto the image's Fourier grid.
        One representative sample is kept per occupied cell.

        Parameters
        ----------
        fov : float
            Size of the FOV in arcseconds.
        ref_frequency : float
            Reference frequency.
        img_size : int
            Size of the image.
        device : str
            Name of the device to run the operation on,
            e.g. ``'cuda'`` or ``'cpu'``.

        Returns
        -------
        torch.tensor
            Stacked fields (see :meth:`__getitem__`) restricted to one
            sample per occupied grid cell.
        """
        uv = torch.cat([self.u_valid[None], self.v_valid[None]], dim=0)

        # np.longdouble is the portable spelling; np.float128 does not
        # exist on every platform.
        fov = np.deg2rad(fov / 3600, dtype=np.longdouble)
        delta = fov ** (-1) * c.value / ref_frequency

        bins = torch.from_numpy(
            np.arange(
                start=-(img_size / 2 + 1 / 2) * delta,
                stop=(img_size / 2 + 1 / 2) * delta,
                step=delta,
                dtype=np.longdouble,
            ).astype(np.float64)
        ).to(device)

        # Floating-point arange can yield one extra edge.
        if len(bins) - 1 > img_size:
            bins = bins[:-1]

        indices_bucket = torch.bucketize(uv, bins)
        if indices_bucket.shape[-1] == 0:
            # Nothing to deduplicate.
            return self[:]

        # Sort samples so identical (u_bin, v_bin) cells become contiguous.
        order, _ = self._lexsort(indices_bucket)
        _, counts = torch.unique_consecutive(
            indices_bucket[:, order],
            dim=1,
            return_counts=True,
        )

        # Offset of the first sorted sample of each run of identical cells.
        first_in_run = torch.cat(
            (
                torch.zeros(1, dtype=counts.dtype, device=counts.device),
                counts.cumsum(0)[:-1],
            )
        )

        return self[:][:, order[first_in_run]]

    def _lexsort(self, a: torch.tensor, dim: int = -1) -> tuple:
        """Sort the columns of a 2-D tensor in lexicographic order.

        Like :func:`numpy.lexsort`, the last row is the primary sort key.

        Parameters
        ----------
        a : torch.tensor
            2-D tensor whose rows are the sort keys.
        dim : int, optional
            The dimension along which to sort. Only ``-1`` is supported.
            Default: ``-1``

        Returns
        -------
        tuple of torch.tensor
            Column permutation that sorts ``a``, and the inverse indices of
            each column into the sorted unique columns.

        Raises
        ------
        NotImplementedError
            If ``dim`` is not ``-1``.
        ValueError
            If ``a`` is not two-dimensional.
        """
        if dim != -1:
            raise NotImplementedError("_lexsort only supports dim=-1.")
        if a.ndim != 2:
            raise ValueError(f"Expected a 2-D tensor, got {a.ndim} dimensions.")

        # Flip the keys so the last row becomes the primary key (numpy order).
        _, inv = torch.unique(a.flip(0), dim=dim, sorted=True, return_inverse=True)
        return torch.argsort(inv), inv
class Observation:
    """Main observation simulation class.

    The :class:`~pyvisgen.simulation.Observation` class simulates the
    baselines and time steps during the observation.

    Parameters
    ----------
    src_ra : float
        Source right ascension coordinate in degrees.
    src_dec : float
        Source declination coordinate in degrees.
    start_time : datetime
        Observation start time.
    scan_duration : int
        Scan duration in seconds.
    num_scans : int
        Number of scans.
    scan_separation : int
        Scan separation in seconds.
    integration_time : int
        Integration time in seconds.
    ref_frequency : float
        Reference frequency.
    frequency_offsets : list
        Frequency offsets.
    bandwidths : list
        Frequency bandwidth.
    fov : float
        Field of view in arcseconds.
    image_size : int
        Image size of the sky distribution.
    array_layout : str
        Name of an existing array layout. See :mod:`~pyvisgen.layouts`.
    corrupted : bool
        If ``True``, apply corruption during the vis loop.
    device : str
        Torch device to select for computation.
    dense : bool, optional
        If ``True``, apply dense baseline calculation of a perfect
        interferometer. Default: ``False``
    sensitivity_cut : float, optional
        Sensitivity threshold, where only pixels above the value are kept.
        Default: ``1e-6``
    polarization : str, optional
        Choose between ``'linear'`` or ``'circular'`` or ``None`` to
        simulate different types of polarizations or disable the
        simulation of polarization. Default: ``None``
    pol_kwargs : dict, optional
        Additional keyword arguments for the simulation of polarization.
        If ``None``, a copy of
        ``{'delta': 0,'amp_ratio': 0.5,'random_state': 42}`` is used.
        Default: ``None``
    field_kwargs : dict, optional
        Additional keyword arguments for the random polarization field
        that is applied when simulating polarization. If ``None``, a copy
        of ``{'order': [1, 1],'scale': [0, 1],'threshold': None,
        'random_state': 42}`` is used. Default: ``None``
    show_progress : bool, optional
        If ``True``, show a progress bar during the iteration over the
        scans. Default: ``False``

    Notes
    -----
    See :class:`~pyvisgen.simulation.polarization` and
    :class:`~pyvisgen.simulation.polarization.rand_polarization_field`
    for more information on the keyword arguments in ``pol_kwargs`` and
    ``field_kwargs``, respectively.
    """

    def __init__(
        self,
        src_ra: float,
        src_dec: float,
        start_time: datetime,
        scan_duration: int,
        num_scans: int,
        scan_separation: int,
        integration_time: int,
        ref_frequency: float,
        frequency_offsets: list,
        bandwidths: list,
        fov: float,
        image_size: int,
        array_layout: str,
        corrupted: bool,
        device: str,
        dense: bool = False,
        sensitivity_cut: float = 1e-6,
        polarization: str = None,
        pol_kwargs: dict = None,
        field_kwargs: dict = None,
        show_progress: bool = False,
    ) -> None:
        """Sets up the observation class.

        All parameters are described in the class docstring.
        """
        import copy

        self.ra = torch.tensor(src_ra).double()
        self.dec = torch.tensor(src_dec).double()

        self.start = Time(start_time.isoformat(), format="isot", scale="utc")
        self.scan_duration = scan_duration
        self.num_scans = num_scans
        self.int_time = integration_time
        self.scan_separation = scan_separation
        self.times, self.times_mjd = self.calc_time_steps()

        # calc_time_steps emits the same number of steps for every scan, so
        # the flat time indices split evenly into (num_scans, steps_per_scan).
        self.scans = torch.stack(
            torch.split(
                torch.arange(self.times.size),
                self.times.size // self.num_scans,
            ),
            dim=0,
        )

        self.ref_frequency = torch.tensor(ref_frequency)
        self.bandwidths = torch.tensor(bandwidths)
        self.frequency_offsets = torch.tensor(frequency_offsets)
        channel_centers = self.ref_frequency + self.frequency_offsets
        self.waves_low = channel_centers - self.bandwidths / 2
        self.waves_high = channel_centers + self.bandwidths / 2

        self.fov = fov
        self.img_size = image_size
        self.pix_size = fov / image_size

        self.corrupted = corrupted
        self.sensitivity_cut = sensitivity_cut
        self.device = torch.device(device)

        self.layout = array_layout
        self.array = layouts.get_array_layout(array_layout)
        self.array_earth_loc = EarthLocation.from_geocentric(
            self.array.x, self.array.y, self.array.z, unit=un.m
        )
        num_stations = len(self.array.st_num)
        # n * (n - 1) is always even, so integer division is exact.
        self.num_baselines = num_stations * (num_stations - 1) // 2

        self.show_progress = show_progress

        if dense:  # pragma: no cover
            self.waves_low = [self.ref_frequency]
            self.waves_high = [self.ref_frequency]
            self.calc_dense_baselines()
            self.ra = torch.tensor([0]).to(self.device)
            self.dec = torch.tensor([0]).to(self.device)
        else:
            self.calc_baselines()
            self.baselines.num = self.num_baselines
            self.baselines.times_unique = torch.unique(self.baselines.time)

        self.rd = self.create_rd_grid()
        self.lm = self.create_lm_grid()

        # polarization
        self.polarization = polarization
        # Fresh copies of the module defaults, so that mutating one
        # instance's kwargs cannot leak into other instances.
        self.pol_kwargs = (
            copy.deepcopy(DEFAULT_POL_KWARGS) if pol_kwargs is None else pol_kwargs
        )
        self.field_kwargs = (
            copy.deepcopy(DEFAULT_FIELD_KWARGS)
            if field_kwargs is None
            else field_kwargs
        )

    def calc_time_steps(self):
        """Computes the time steps of the observation.

        Each scan starts after the previous scans plus their separations
        and contributes ``int(scan_duration / int_time) + 1`` time steps.

        Returns
        -------
        time : :class:`~astropy.time.Time`
            Array of time steps.
        time_mjd : array_like
            Time steps as MJD expressed in seconds (``mjd * 86400``).
        """
        # +1 because t_1 is the stop time of t_0: one extra step completes
        # the last integration interval of each scan.
        steps_per_scan = int(self.scan_duration / self.int_time) + 1
        time_lst = [
            self.start
            + self.scan_separation * i * un.second
            + i * self.scan_duration * un.second
            + j * self.int_time * un.second
            for i in range(self.num_scans)
            for j in range(steps_per_scan)
        ]
        time = Time(time_lst)

        return time, time.mjd * (60 * 60 * 24)

    def calc_dense_baselines(self):  # pragma: no cover
        """Calculates the baselines of a densely-built antenna array,
        which would provide full coverage of the uv space.

        Stores the result as ``self.dense_baselines_gpu``.
        """
        N = self.img_size
        # np.longdouble instead of np.float128 for platform portability.
        fov = np.deg2rad(self.fov / 3600, dtype=np.longdouble)
        delta = fov ** (-1) * c.value / self.ref_frequency

        u_dense = torch.from_numpy(
            np.arange(
                start=-(N / 2) * delta,
                stop=(N / 2) * delta,
                step=delta,
                dtype=np.longdouble,
            ).astype(np.float64)
        ).to(self.device)

        v_dense = u_dense

        # "ij" is torch's default; passing it explicitly avoids the warning.
        uu, vv = torch.meshgrid(u_dense, v_dense, indexing="ij")
        u = uu.flatten()
        v = vv.flatten()

        self.dense_baselines_gpu = torch.stack(
            [
                u,
                u,
                u,
                v,
                v,
                v,
                torch.zeros(u.shape, device=self.device),
                torch.zeros(u.shape, device=self.device),
                torch.zeros(u.shape, device=self.device),
                torch.ones(u.shape, device=self.device),
                torch.ones(u.shape, device=self.device),
            ]
        )

    def calc_baselines(self):
        """Initializes :class:`~pyvisgen.simulation.Baselines` dataclass
        object and calls
        :py:func:`~pyvisgen.simulation.Observation.get_baselines` to
        compute the contents of the
        :class:`~pyvisgen.simulation.Baselines` dataclass.

        The result is stored in ``self.baselines``.
        """
        self.baselines = Baselines(
            torch.tensor([]),  # st1
            torch.tensor([]),  # st2
            torch.tensor([]),  # u
            torch.tensor([]),  # v
            torch.tensor([]),  # w
            torch.tensor([]),  # valid
            torch.tensor([]),  # time
            torch.tensor([]),  # q1
            torch.tensor([]),  # q2
        )

        # Wrap in a local iterator so ``self.scans`` stays a tensor.
        scan_iter = tqdm(
            self.scans,
            disable=not self.show_progress,
            desc="Computing scans",
        )

        for scan in scan_iter:
            bas = self.get_baselines(self.times[scan])
            self.baselines.add_baseline(bas)

    def get_baselines(self, times):
        """Calculates baselines from source coordinates and time of
        observation for every antenna station in array_layout.

        Parameters
        ----------
        times : time object
            time of observation

        Returns
        -------
        dataclass object
            baselines between telescopes with visibility flags
        """
        # catch rare case where dimension of times is 0
        if times.ndim == 0:
            times = Time([times])

        # calculate GHA, local HA, and station elevation for all times.
        GHA, ha_local, el_st_all = self.calc_ref_elev(time=times)

        ar = Array(self.array)
        delta_x, delta_y, delta_z = ar.calc_relative_pos
        st_num_pairs, els_low_pairs, els_high_pairs = ar.calc_ant_pair_vals

        baselines = Baselines(
            torch.tensor([]),  # st1
            torch.tensor([]),  # st2
            torch.tensor([]),  # u
            torch.tensor([]),  # v
            torch.tensor([]),  # w
            torch.tensor([]),  # valid
            torch.tensor([]),  # time
            torch.tensor([]),  # q1
            torch.tensor([]),  # q2
        )

        # Parallactic angles of every antenna pair, grouped per time step.
        q_all = self.calc_feed_rotation(ha_local)
        q_comb = torch.vstack([torch.combinations(qi) for qi in q_all])
        q_comb = q_comb.reshape(-1, q_comb.shape[0] // times.shape[0], 2)

        # Loop over ha, el_st, times, parallactic angles
        for ha, el_st, time, q, qc in zip(GHA, el_st_all, times, q_all, q_comb):
            u, v, w = self.calc_direction_cosines(ha, el_st, delta_x, delta_y, delta_z)

            # calc current elevations
            cur_el_st = torch.combinations(el_st)

            # A baseline is invalid if either antenna is outside its
            # elevation limits.
            m1 = (cur_el_st < els_low_pairs).any(axis=1)
            m2 = (cur_el_st > els_high_pairs).any(axis=1)
            valid = torch.ones(u.shape).bool()
            valid_mask = torch.logical_or(m1, m2)
            valid[valid_mask] = False

            time_mjd = torch.repeat_interleave(
                torch.tensor(time.mjd) * (24 * 60 * 60), len(valid)
            )

            # collect baselines
            base = Baselines(
                st_num_pairs[..., 0],
                st_num_pairs[..., 1],
                u,
                v,
                w,
                valid,
                time_mjd,
                qc[..., 0].ravel(),
                qc[..., 1].ravel(),
            )
            baselines.add_baseline(base)

        return baselines

    def calc_ref_elev(self, time=None) -> tuple:
        """Calculates the station elevations for given time steps.

        Parameters
        ----------
        time : array_like or None, optional
            Array containing observation time steps. If ``None``,
            ``self.times`` is used. Default: ``None``

        Returns
        -------
        tuple
            Tuple containing tensors of the Greenwich hour angle (degrees),
            antenna-local hour angles (radians), and the elevations
            (degrees).

        Raises
        ------
        ValueError
            If the hour angles and elevations do not cover the same number
            of time steps.
        """
        if time is None:
            time = self.times

        if time.shape == ():
            time = time[None]

        src_crd = SkyCoord(ra=self.ra, dec=self.dec, unit=(un.deg, un.deg))

        # calculate GHA, Greenwich as reference
        GHA = time.sidereal_time("apparent", "greenwich") - src_crd.ra.to(un.hourangle)

        # calculate local sidereal time and HA at each antenna
        lst = un.Quantity(
            [
                Time(time, location=loc).sidereal_time("mean")
                for loc in self.array_earth_loc
            ]
        )
        ha_local = torch.from_numpy(
            (lst - Longitude(self.ra.item(), unit=un.deg)).radian
        ).T

        # calculate elevations
        el_st_all = src_crd.transform_to(
            AltAz(
                obstime=time[..., None],
                location=EarthLocation.from_geocentric(
                    torch.repeat_interleave(self.array.x[None], len(time), dim=0),
                    torch.repeat_interleave(self.array.y[None], len(time), dim=0),
                    torch.repeat_interleave(self.array.z[None], len(time), dim=0),
                    unit=un.m,
                ),
            )
        )

        if len(GHA.value) != len(el_st_all):
            raise ValueError(
                "Expected GHA and el_st_all to have the same length, "
                f"got {len(GHA.value)} and {len(el_st_all)}."
            )

        return (
            torch.tensor(GHA.deg),
            ha_local,
            torch.tensor(el_st_all.alt.degree),
        )

    def calc_feed_rotation(self, ha: torch.tensor) -> torch.tensor:
        r"""Calculates feed rotation for every antenna at every time step.

        Parameters
        ----------
        ha : :func:`torch.tensor`
            Antenna-local hour angles in radians, broadcastable against the
            per-antenna latitudes (as returned by :meth:`calc_ref_elev`).

        Returns
        -------
        :func:`torch.tensor`
            Parallactic angles in radians.

        Notes
        -----
        The calculation is based on Equation (13.1) of Meeus'
        Astronomical Algorithms:

        .. math::

            q = \arctan\left(\frac{\sin h}{\cos\delta \tan\varphi
                - \sin\delta \cos h}\right),

        where $h$ is the local hour angle, $\varphi$ the geographical
        latitude of the observer, and $\delta$ the declination of the
        source. All angles must be in radians for the torch trigonometric
        functions.
        """
        # EarthLocation.lat is a Latitude in degrees; convert to radians
        # before feeding it to torch.tan.
        ant_lat = torch.tensor(self.array_earth_loc.lat.radian)
        # self.dec is stored in degrees.
        dec = torch.deg2rad(self.dec)

        # Eqn (13.1) of Meeus' Astronomical Algorithms
        q = torch.arctan2(
            torch.sin(ha),
            (torch.tan(ant_lat) * torch.cos(dec) - torch.sin(dec) * torch.cos(ha)),
        )

        return q

    def create_rd_grid(self):
        """Calculates RA and Dec offsets for the FOV around the source position.

        Uses ``self.fov`` (arcseconds), ``self.img_size`` (pixels) and
        ``self.dec`` (degrees).

        Returns
        -------
        rd_grid : :func:`torch.tensor`
            Tensor of shape ``(img_size, img_size, 2)`` in radians. Index 0 of
            the last axis holds the RA offset and index 1 holds the
            declination (source declination plus offset).
        """
        # transform to rad (np.longdouble instead of non-portable np.float128)
        fov = np.deg2rad(self.fov / 3600, dtype=np.longdouble)

        # define resolution
        res = fov / self.img_size

        dec = torch.deg2rad(self.dec).to(self.device)

        r = torch.from_numpy(
            np.arange(
                start=-(self.img_size / 2) * res,
                stop=(self.img_size / 2) * res,
                step=res,
                dtype=np.longdouble,
            ).astype(np.float64)
        ).to(self.device)
        d = r + dec

        R, _ = torch.meshgrid((r, r), indexing="ij")
        _, D = torch.meshgrid((d, d), indexing="ij")
        rd_grid = torch.cat([R[..., None], D[..., None]], dim=2)

        return rd_grid

    def create_lm_grid(self):
        """Calculates the sine (l, m) projection of ``self.rd`` for the FOV.

        Requires ``self.rd`` (see :meth:`create_rd_grid`) and uses
        ``self.dec`` (degrees) as the projection center declination.

        Returns
        -------
        lm_grid : :func:`torch.tensor`
            Tensor with the same shape as ``self.rd``. Index 0 of the last
            axis holds l and index 1 holds m.
        """
        dec = np.deg2rad(self.dec.cpu().numpy()).astype(np.longdouble)
        rd = self.rd.cpu().numpy().astype(np.longdouble)

        lm_grid = np.zeros(rd.shape, dtype=np.longdouble)
        lm_grid[..., 0] = np.cos(rd[..., 1]) * np.sin(rd[..., 0])
        lm_grid[..., 1] = np.sin(rd[..., 1]) * np.cos(dec) - np.cos(
            rd[..., 1]
        ) * np.sin(dec) * np.cos(rd[..., 0])

        return torch.from_numpy(lm_grid.astype(np.float64)).to(self.device)

    def calc_direction_cosines(
        self,
        ha: torch.tensor,
        el_st: torch.tensor,
        delta_x: torch.tensor,
        delta_y: torch.tensor,
        delta_z: torch.tensor,
    ):
        """Calculates direction cosines u, v, and w for given hour angles
        and relative antenna positions.

        Parameters
        ----------
        ha : :func:`torch.tensor`
            Tensor containing hour angles (degrees) for each time step.
        el_st : :func:`torch.tensor`
            Tensor containing station elevations (not used in the
            computation).
        delta_x : :func:`torch.tensor`
            Tensor containing relative antenna x-positions.
        delta_y : :func:`torch.tensor`
            Tensor containing relative antenna y-positions.
        delta_z : :func:`torch.tensor`
            Tensor containing relative antenna z-positions.

        Returns
        -------
        u : :func:`torch.tensor`
            Tensor containing direction cosines in u-axis direction.
        v : :func:`torch.tensor`
            Tensor containing direction cosines in v-axis direction.
        w : :func:`torch.tensor`
            Tensor containing direction cosines in w-axis direction.

        Raises
        ------
        ValueError
            If u, v, and w do not end up with the same shape.
        """
        src_dec = torch.deg2rad(self.dec)
        ha = torch.deg2rad(ha)

        u = (torch.sin(ha) * delta_x + torch.cos(ha) * delta_y).reshape(-1)
        v = (
            -torch.sin(src_dec) * torch.cos(ha) * delta_x
            + torch.sin(src_dec) * torch.sin(ha) * delta_y
            + torch.cos(src_dec) * delta_z
        ).reshape(-1)
        w = (
            torch.cos(src_dec) * torch.cos(ha) * delta_x
            - torch.cos(src_dec) * torch.sin(ha) * delta_y
            + torch.sin(src_dec) * delta_z
        ).reshape(-1)

        if not (u.shape == v.shape == w.shape):
            raise ValueError(
                "Expected u, v, and w to have the same shapes "
                f"but got {u.shape}, {v.shape}, and {w.shape}."
            )

        return u, v, w