"""Source code for pyvisgen.io.datawriters: dataset writers for pyvisgen."""

import os
from abc import ABC, abstractmethod
from io import BytesIO
from pathlib import Path
from typing import Self

import numpy as np
import torch
from h5py import File

from pyvisgen.fits.writer import create_hdu_list

try:
    import pyarrow as pa
    import pyarrow.parquet as pq
    import webdataset as wds

    _WDS_AVAIL = True
except ImportError:
    _WDS_AVAIL = False


# Public API: the abstract base class followed by the concrete writers.
__all__ = [
    "DataWriter", "FITSWriter", "H5Writer", "PTWriter", "UVH5Writer", "WDSShardWriter",
]


[docs] class DataWriter(ABC): """Abstract base class for data writers in pyvisgen. This class contains methods to get half images and test the shapes of the data prior to writing. It also supports a context manager protocol. Subclasses must implement the `__init__` and `write` methods to define writing behavior. Parameters ---------- output_path : str or Path Path where the dataset will be written. dataset_type : str Type of dataset being written (e.g., 'train', 'test', 'validation'). *args Variable length argument list passed to subclass implementations. **kwargs Arbitrary keyword arguments passed to subclass implementations. Examples -------- >>> class MyWriter(DataWriter): ... def __init__(self, output_path, dataset_type): ... self.output_path = output_path ... self.dataset_type = dataset_type ... ... def write(self, data): ... # Implementation here ... pass >>> >>> with MyWriter("output_file", dataset_type="train") as writer: ... writer.write(data) """ @abstractmethod def __init__( self, output_path: Path, dataset_type: str, *args, **kwargs ) -> None: # pragma: no cover """Initialize the data writer. This method must be implemented by subclasses to handle the setup of the context manager. Parameters ---------- output_path : str or Path Path where the dataset will be written. dataset_type : str Type of dataset being written. *args Additional positional arguments for subclass-specific initialization. **kwargs Additional keyword arguments for subclass-specific initialization. """ ...
[docs] @abstractmethod def write(self, *args, **kwargs) -> None: # pragma: no cover """Write data to the output destination. This method must be implemented by subclasses to handle the actual writing of data to the specified output format. Parameters ---------- *args Data and parameters required for writing, defined by subclass. **kwargs Additional options for writing, defined by subclass. """ ...
[docs] def test_shapes(self, array: np.ndarray, name: str) -> None: """Validate the shape of input arrays. Arrays should have the shape (B, C, H, W), where B is the batch size, C the number of channels (2), and W and H the width and height of the images. Parameters ---------- array : np.ndarray Array to validate. name : str Name of the array for error reporting. Raises ------ ValueError If array axis 1 is not size 2. ValueError If array does not have exactly 4 dimensions. """ if array.shape[1] != 2: raise ValueError( f"Expected array {name} axis 1 to be 2 but got " f"{array.shape} with axis 1: {array.shape[1]}! " "This usually indicates that the images do not have " "separate channels for amplitude/phase or real/imaginary " "data." ) if array.ndim != 4: raise ValueError( f"Expected array {name} ndim to be 4 but got " f"{array.shape} with ndim {array.ndim}!" )
[docs] def get_half_image( self, x: np.ndarray, y: np.ndarray, overlap: int = 5 ) -> tuple[np.ndarray]: """Extract half height of every image with a small overlap. Parameters ---------- x : np.ndarray Simulated data array with shape (B, C, H, W). y : np.ndarray Ground truth array with shape (B, C, H, W). Returns ------- tuple[np.ndarray, np.ndarray] Tuple containing the cropped x and y arrays. """ half_image = x.shape[2] // 2 x = x[:, :, : half_image + overlap, :] y = y[:, :, : half_image + overlap, :] return x, y
def __enter__(self) -> Self: """Enter the context manager. Returns ------- Self The DataWriter instance itself. """ return self def __exit__(self, exc_type, exc_value, traceback) -> None: """Exit the context manager. Performs cleanup when exiting the context. Default implementation does nothing; subclasses can override to add cleanup logic. Parameters ---------- exc_type : type or None The type of exception that occurred, if any. exc_value : Exception or None The exception instance that occurred, if any. traceback : traceback or None The traceback object for the exception, if any. Returns ------- None Returns ``None`` per default. """ return None
class H5Writer(DataWriter):
    """HDF5 file writer for pyvisgen datasets.

    This writer saves data arrays to HDF5 files using the h5py library.
    Each call to :meth:`write` produces a separate ``.h5`` file. By default
    the writer crops images to half their height with a small overlap.
    Array shapes are validated before writing.

    Parameters
    ----------
    output_path : str or Path
        Directory path where HDF5 files will be written.
    dataset_type : str
        Type of dataset being written (e.g., 'train', 'test', 'validation').
        This is used in the output filename pattern.
    half_image : bool, optional
        If ``True`` (default), only the first ``H // 2 + overlap`` rows of
        every image are written.
    **kwargs
        Ignored; accepted for compatibility with other writers.

    Examples
    --------
    >>> writer = H5Writer(output_path="./data", dataset_type="train")
    >>> writer.write(x_data, y_data, index=0)

    Or as a context manager:

    >>> rng = np.random.default_rng()
    >>>
    >>> with H5Writer(output_path="./data", dataset_type="train") as writer:
    ...     x_data = rng.uniform(size=(5, 10, 2, 256, 256))
    ...     y_data = rng.uniform(size=(5, 10, 2, 256, 256))
    ...
    ...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
    ...         writer.write(x, y, index=bundle_id)
    """

    def __init__(
        self, output_path: Path, dataset_type: str, half_image: bool = True, **kwargs
    ) -> None:
        """Initialize the HDF5 writer.

        Parameters
        ----------
        output_path : str or Path
            Directory path where HDF5 files will be written.
        dataset_type : str
            Type of dataset being written (e.g., 'train', 'test', 'validation').
        half_image : bool, optional
            If ``True`` (default), crop images to half their height
            (plus overlap) before writing.
        """
        self.output_path = Path(output_path)
        self.dataset_type = dataset_type
        self.half_image = half_image

        # Disables HDF5 file locking for this process (the setting is
        # process-wide). NOTE(review): presumably for file systems without
        # lock support — confirm.
        os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

    def write(
        self,
        x,
        y,
        index,
        name_x="x",
        name_y="y",
        overlap: int = 5,
        **kwargs,
    ) -> None:
        """Write an FFT pair to an HDF5 file.

        Creates a new HDF5 file per call with pattern
        ``samp_{dataset_type}_{index}.h5``. The input arrays are validated
        and then, if ``half_image`` is enabled, cropped to half their height
        (with ``overlap`` rows of overlap) before writing.

        Parameters
        ----------
        x : np.ndarray
            First array of the FFT pair with shape (batch, 2, height, width).
        y : np.ndarray
            Second array of the FFT pair with shape (batch, 2, height, width).
        index : int
            Bundle index used in the output filename.
        name_x : str, optional
            Key of the dataset for the x array in the HDF5 file.
            Default: ``"x"``.
        name_y : str, optional
            Key of the dataset for the y array in the HDF5 file.
            Default: ``"y"``.
        overlap : int, optional
            Overlap parameter for extracting half-images. Default: 5.

        Raises
        ------
        ValueError
            If x or y does not have 4 dimensions with axis 1 of size 2.

        Examples
        --------
        >>> rng = np.random.default_rng()
        >>>
        >>> with H5Writer(output_path="./data", dataset_type="train") as writer:
        ...     x_data = rng.uniform(size=(5, 10, 2, 256, 256))
        ...     y_data = rng.uniform(size=(5, 10, 2, 256, 256))
        ...
        ...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
        ...         writer.write(x, y, index=bundle_id)
        """
        # Validate before cropping: malformed input would otherwise fail
        # with an IndexError inside get_half_image instead of a ValueError.
        # Cropping keeps ndim and axis 1 unchanged, so valid input stays valid.
        self.test_shapes(x, "x")
        self.test_shapes(y, "y")

        if self.half_image:
            x, y = self.get_half_image(x, y, overlap=overlap)

        output_file = Path(self.output_path) / f"samp_{self.dataset_type}_{index}.h5"

        with File(output_file, "w") as f:
            f.create_dataset(name_x, data=x)
            f.create_dataset(name_y, data=y)
class FITSWriter(DataWriter):
    """FITS file writer for pyvisgen visibility datasets.

    This writer saves visibility data and observation information to FITS
    (Flexible Image Transport System) files. Each sample is written to a
    separate ``.fits`` file.

    Parameters
    ----------
    output_path : str or Path
        Directory path where FITS files will be written.
    dataset_type : str
        Type of dataset being written (e.g., 'train', 'test', 'validation').
        This is used in the file names.
    **kwargs
        Ignored; accepted for compatibility with other writers.

    Examples
    --------
    >>> writer = FITSWriter(output_path="./data", dataset_type="train")
    >>> writer.write(vis_data, obs, index=0)

    Or as a context manager:

    >>> with FITSWriter(output_path="./data", dataset_type="train") as writer:
    ...     writer.write(vis_data, obs, index=0)
    """

    def __init__(self, output_path: Path, dataset_type: str, **kwargs) -> None:
        """Initialize the FITS writer.

        Parameters
        ----------
        output_path : str or Path
            Directory path where FITS files will be written.
        dataset_type : str
            Type of dataset being written; used in the file names.
        """
        self.output_path = Path(output_path)
        self.dataset_type = dataset_type

    def write(
        self,
        vis_data,
        obs,
        index,
        overwrite=True,
        **kwargs,
    ) -> None:
        """Write visibility data and observation metadata to a FITS file.

        Creates a new FITS file for each sample with pattern
        ``vis_{dataset_type}_{index}.fits``.

        Parameters
        ----------
        vis_data : array-like
            Visibility data to be written to the FITS file.
        obs : object
            Observation metadata object from
            :class:`~pyvisgen.simulation.Observation`.
        index : int
            Sample index used in the output filename.
        overwrite : bool, optional
            If ``True``, overwrite the output file if it already exists,
            otherwise an error is raised. Default: ``True``.

        See Also
        --------
        pyvisgen.fits.writer.create_hdu_list :
            For more information on the parameters.

        Examples
        --------
        >>> writer = FITSWriter(output_path="./data", dataset_type="train")
        >>> writer.write(vis, obs, index=0)
        >>> # Creates file: ./data/vis_train_0.fits
        >>> writer.write(vis, obs, index=1, overwrite=False)
        >>> # Creates file: ./data/vis_train_1.fits (raises error if exists)
        """
        output_file = Path(self.output_path) / f"vis_{self.dataset_type}_{index}.fits"

        hdu_list = create_hdu_list(vis_data, obs)
        hdu_list.writeto(output_file, overwrite=overwrite)
class UVH5Writer(DataWriter):
    """HDF5 file writer for UV-plane simulation data.

    This writer saves visibilities, UVW coordinates, LMN coordinates,
    frequency and time information, observation metadata and (optionally)
    the simulated sky to a single HDF5 file per sample. The file layout
    is::

        {dataset_type}_{index}.uvh5
        ├── visibilities/
        │   ├── V_11
        │   ├── V_22
        │   ├── V_12
        │   ├── V_21
        │   └── weights
        ├── uvw/
        │   ├── u
        │   ├── v
        │   ├── w
        │   └── st_id_pairs
        ├── lmn/
        │   ├── l
        │   ├── m
        │   └── n
        ├── frequency_bands
        ├── channel_widths
        ├── normalize
        ├── times
        ├── obs/
        │   ├── ra
        │   ├── dec
        │   └── layout
        └── sky/            (only written if ``sky`` is given)
            └── SI

    Dataset dtypes follow the values passed in; torch tensors are
    converted to NumPy arrays before writing.

    Parameters
    ----------
    output_path : str or Path
        Directory path where HDF5 files will be written.
    dataset_type : str
        Type of dataset being written (e.g., 'train', 'test', 'validation').
        Used in the output filename pattern.
    **kwargs
        Ignored; accepted for compatibility with other writers.

    Examples
    --------
    >>> writer = UVH5Writer(output_path="./data", dataset_type="train")
    >>> writer.write(vis_data, obs, index=0, sky=SI)

    Or as a context manager:

    >>> with UVH5Writer(output_path="./data", dataset_type="train") as writer:
    ...     writer.write(vis_data, obs, index=0, sky=SI)
    """

    def __init__(self, output_path: Path, dataset_type: str, **kwargs) -> None:
        """Initialize the UVH5 writer.

        Parameters
        ----------
        output_path : str or Path
            Directory path where HDF5 files will be written.
        dataset_type : str
            Type of dataset being written (e.g., 'train', 'test', 'validation').
        """
        self.output_path = Path(output_path)
        self.dataset_type = dataset_type

        # Disables HDF5 file locking for this process (the setting is
        # process-wide). NOTE(review): presumably for file systems without
        # lock support — confirm.
        os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

    def write(
        self,
        vis_data,
        obs,
        index: int,
        sky=None,
        normalize: bool = True,
        **kwargs,
    ) -> None:
        """Write simulation data to an HDF5 file.

        Creates a new HDF5 file for each sample with pattern
        ``{dataset_type}_{index}.uvh5``.

        Parameters
        ----------
        vis_data : Visibilities
            Visibilities dataclass object from
            :func:`~pyvisgen.simulation.visibility.vis_loop`. Its attributes
            ``V_11``, ``V_22``, ``V_12``, ``V_21``, ``weights``, ``u``,
            ``v``, ``w``, ``st_id_pairs`` and ``date`` are written.
        obs : Observation
            Observation object from :class:`~pyvisgen.simulation.Observation`.
            ``obs.lm``, ``obs.ref_frequency``, ``obs.frequency_offsets``,
            ``obs.bandwidths``, ``obs.ra``, ``obs.dec`` and ``obs.layout``
            are written.
        index : int
            Sample index used in the output filename.
        sky : torch.Tensor or np.ndarray, optional
            Sky intensity distribution (SI) passed to the visibility
            simulation. If ``None``, the ``sky/`` group is omitted from the
            output file.
        normalize : bool, optional
            Flag stored as the ``normalize`` dataset. It does not change any
            of the written data. Default: ``True``.

        Examples
        --------
        >>> writer = UVH5Writer(output_path="./data", dataset_type="train")
        >>> writer.write(vis_data, obs, index=0, sky=SI)
        """
        output_file = Path(self.output_path) / f"{self.dataset_type}_{index}.uvh5"

        lm = self.__to_numpy(obs.lm)  # last axis holds the (l, m) components
        # n = sqrt(1 - l^2 - m^2); negative radicands (points outside the
        # unit circle) are clipped to 0 so the result stays real.
        n = np.sqrt(np.maximum(1.0 - lm[..., 0] ** 2 - lm[..., 1] ** 2, 0.0))

        with File(output_file, "w") as f:
            vis_grp = f.create_group("visibilities")
            for key in ("V_11", "V_22", "V_12", "V_21", "weights"):
                vis_grp.create_dataset(
                    key, data=self.__to_numpy(getattr(vis_data, key))
                )

            uvw_grp = f.create_group("uvw")
            for key in ("u", "v", "w", "st_id_pairs"):
                uvw_grp.create_dataset(
                    key, data=self.__to_numpy(getattr(vis_data, key))
                )

            lmn_grp = f.create_group("lmn")
            lmn_grp.create_dataset("l", data=lm[..., 0])
            lmn_grp.create_dataset("m", data=lm[..., 1])
            lmn_grp.create_dataset("n", data=n)

            freq_bands = self.__to_numpy(obs.ref_frequency + obs.frequency_offsets)
            f.create_dataset("frequency_bands", data=freq_bands)
            f.create_dataset("channel_widths", data=self.__to_numpy(obs.bandwidths))
            f.create_dataset("normalize", data=np.bool_(normalize))
            f.create_dataset("times", data=self.__to_numpy(vis_data.date))

            obs_grp = f.create_group("obs")
            obs_grp.create_dataset("ra", data=self.__to_numpy(obs.ra))
            obs_grp.create_dataset("dec", data=self.__to_numpy(obs.dec))
            obs_grp.create_dataset("layout", data=obs.layout)

            if sky is not None:
                sky_grp = f.create_group("sky")
                sky_grp.create_dataset("SI", data=self.__to_numpy(sky))

    def __to_numpy(self, t) -> np.ndarray:
        """Convert a torch tensor (detached, on CPU) or array-like to NumPy."""
        if isinstance(t, torch.Tensor):
            return t.detach().cpu().numpy()

        return np.asarray(t)
class WDSShardWriter(DataWriter):
    """WebDataset file writer for pyvisgen datasets.

    This writer saves data arrays to .tar(.gz) files using the WebDataset
    library. Each bundle is written to a separate .tar file, accompanied by
    a .parquet metadata file with the same stem. By default the writer
    crops images to half their height with a small overlap. Array shapes
    are validated before writing.

    Parameters
    ----------
    output_path : str or Path
        Directory path where .tar files will be written.
    dataset_type : str
        Type of dataset being written (e.g., 'train', 'test', 'validation').
        This is used in the file names and sample keys.
    shard_pattern : str
        printf-style format string for naming shard files. It must contain
        one integer specifier for the shard index (e.g., ``"%06d.tar"``).
        :meth:`write` prefixes it with ``"{dataset_type}-"``, so
        ``"%06d.tar"`` yields ``train-000000.tar``.
    amp_phase : bool
        If ``True``, saves "amp_phase" to the .parquet metadata files;
        if ``False``, saves "real_imag" instead.
    compress : bool, optional
        If ``True``, compresses shards using gzip compression and appends
        ``'.gz'`` to the shard pattern if it does not already end with it.
        Default is ``False``.
    total_samples : int
        Total number of samples in the dataset; stored in the metadata.
    half_image : bool, optional
        If ``True`` (default), only the first ``H // 2 + overlap`` rows of
        every image are written.
    **kwargs
        Additional keyword arguments for compatibility with other writers.

    Examples
    --------
    >>> writer = WDSShardWriter(
    ...     output_path="./data",
    ...     dataset_type="train",
    ...     total_samples=total_samples,
    ...     shard_pattern="%06d.tar",
    ...     amp_phase=True,
    ... )
    >>> writer.write(x_data, y_data, index=0)

    Or as a context manager:

    >>> rng = np.random.default_rng()
    >>>
    >>> with WDSShardWriter(
    ...     output_path="./data",
    ...     dataset_type="train",
    ...     total_samples=total_samples,
    ...     shard_pattern="%06d.tar",
    ...     amp_phase=True,
    ... ) as writer:
    ...     x_data = rng.uniform(size=(5, 10, 2, 256, 256))
    ...     y_data = rng.uniform(size=(5, 10, 2, 256, 256))
    ...
    ...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
    ...         writer.write(x, y, index=bundle_id, overlap=5)
    """

    def __init__(
        self,
        output_path: str | Path,
        *,
        dataset_type: str,
        shard_pattern: str,
        amp_phase: bool,
        compress: bool = False,
        total_samples: int,
        half_image: bool = True,
        **kwargs,
    ) -> None:
        """Initialize the WebDataset writer.

        Parameters
        ----------
        output_path : str or Path
            Directory path where .tar files will be written.
        dataset_type : str
            Type of dataset being written; prefixed to every shard name.
        shard_pattern : str
            printf-style format string with one integer specifier for the
            shard index (e.g., ``"%06d.tar"``).
        amp_phase : bool
            If ``True``, saves "amp_phase" to the .parquet metadata files;
            if ``False``, saves "real_imag" instead.
        compress : bool, optional
            If ``True``, gzip-compress shards and append ``'.gz'`` to the
            shard pattern when missing. Default is ``False``.
        total_samples : int
            Total number of samples in the dataset; stored in the metadata.
        half_image : bool, optional
            If ``True`` (default), crop images to half their height
            (plus overlap) before writing.
        **kwargs
            Additional keyword arguments for compatibility with other writers.

        Raises
        ------
        ImportError
            If the optional webdataset/pyarrow dependencies are missing.
        """
        if not _WDS_AVAIL:
            raise ImportError(
                "Could not import webdataset. Please make sure you install "
                "pyvisgen with the webdataset extra: "
                "uv pip install pyvisgen[webdataset]"
            )

        self.output_path = Path(output_path)
        self.dataset_type = dataset_type
        self.compress = compress
        self.half_image = half_image
        self.total_samples = total_samples
        self.data_type = "amp_phase" if amp_phase else "real_imag"

        # Append the suffix rather than replacing every ".tar" occurrence
        # in the pattern.
        if compress and not shard_pattern.endswith(".gz"):
            shard_pattern += ".gz"
        self.shard_pattern = shard_pattern

        # keeping track of IDs
        self.current_shard_id = 0
        self.total_samples_written = 0
        self.shards_written = 0

    def write(
        self,
        x: np.ndarray,
        y: np.ndarray,
        index: int,
        overlap=5,
        **kwargs,
    ) -> None:
        """Write a data bundle to its own .tar(.gz) shard plus metadata.

        The input arrays are validated and then, if ``half_image`` is
        enabled, cropped to half their height (with ``overlap`` rows of
        overlap). Each sample is stored as ``input.npy``/``target.npy``
        inside the shard, and a .parquet metadata file is written next to
        the shard.

        Parameters
        ----------
        x : np.ndarray
            First array of the FFT pair with shape (batch, 2, height, width).
        y : np.ndarray
            Second array of the FFT pair with shape (batch, 2, height, width).
        index : int
            Bundle index stored in the metadata.
        overlap : int, optional
            Overlap parameter for extracted half-images. Default: 5.

        Raises
        ------
        ValueError
            If x or y does not have 4 dimensions with axis 1 of size 2, or
            if x and y have different batch sizes.
        """
        # Validate before cropping so malformed input raises a ValueError
        # instead of an IndexError from the slicing in get_half_image.
        self.test_shapes(x, "x")
        self.test_shapes(y, "y")
        if x.shape[0] != y.shape[0]:
            # zip() below would silently drop samples and the metadata
            # would report the wrong count.
            raise ValueError(
                f"Expected x and y to have the same batch size but got "
                f"{x.shape[0]} and {y.shape[0]}!"
            )

        if self.half_image:
            inputs, targets = self.get_half_image(x, y, overlap=overlap)
        else:
            inputs, targets = x, y

        bundle_length = inputs.shape[0]
        filename = self.dataset_type + "-" + (self.shard_pattern % self.current_shard_id)
        shard_path = self.output_path / filename

        with wds.TarWriter(str(shard_path), compress=self.compress) as tarwriter:
            for sample_x, sample_y in zip(inputs, targets):
                sample = {
                    "__key__": f"{self.dataset_type}_{self.total_samples_written:08d}",
                    "input.npy": self._serialize_numpy(sample_x),
                    "target.npy": self._serialize_numpy(sample_y),
                }
                tarwriter.write(sample)
                self.total_samples_written += 1

        metadict = {
            "total_samples_in_dataset": [self.total_samples],
            "samples_in_shard": [bundle_length],
            "shard_idx": [self.current_shard_id],
            "bundle_id": [index],
            "data_type": [self.data_type],
        }
        metadata = pa.Table.from_pydict(metadict)
        pq.write_table(metadata, str(self._metadata_path(shard_path)))

        self.current_shard_id += 1
        self.shards_written += 1

    @staticmethod
    def _metadata_path(shard_path: Path) -> Path:
        """Return the .parquet path next to ``shard_path`` with the same stem.

        Only the file name's trailing ``.tar.gz``/``.tar`` suffix is
        replaced. The directory is never touched, and a shard with another
        extension never resolves to the shard path itself.
        """
        name = shard_path.name
        if name.endswith(".tar.gz"):
            name = name.removesuffix(".tar.gz")
        else:
            name = name.removesuffix(".tar")

        return shard_path.with_name(name + ".parquet")

    def _serialize_numpy(self, array: np.ndarray) -> bytes:
        """Serialize ``array`` to bytes in NumPy ``.npy`` format."""
        buffer = BytesIO()
        np.save(buffer, array)

        return buffer.getvalue()
class PTWriter(DataWriter):
    """DataWriter class for saving data in PyTorch (.pt) format.

    Creates a new .pt file for each sample with pattern
    ``samp_{dataset_type}_{sample_index}.pt``. The input arrays are
    validated and, by default, cropped to half their height (with
    ``overlap`` rows of overlap) before writing.

    Parameters
    ----------
    output_path : str or Path
        Directory path where .pt files will be written.
    dataset_type : str
        Type of dataset being written (e.g., 'train', 'test', 'validation').
    amp_phase : bool
        If True, metadata ``TYPE`` key will contain 'amp_phase', otherwise
        'real_imag'.
    half_image : bool, optional
        If ``True`` (default), only the first ``H // 2 + overlap`` rows of
        every image are written.
    **kwargs
        Ignored; accepted for compatibility with other writers.

    Examples
    --------
    >>> writer = PTWriter(output_path="./data", dataset_type="train", amp_phase=True)
    >>> writer.write(x_data, y_data, index=0, bundle_length=len(x_data))

    Or as a context manager:

    >>> rng = np.random.default_rng()
    >>>
    >>> with PTWriter(
    ...     output_path="./data", dataset_type="train", amp_phase=True
    ... ) as writer:
    ...     x_data = rng.uniform(size=(5, 10, 2, 256, 256))
    ...     y_data = rng.uniform(size=(5, 10, 2, 256, 256))
    ...
    ...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
    ...         writer.write(x, y, index=bundle_id, bundle_length=len(x))
    """

    def __init__(
        self,
        output_path: Path,
        dataset_type: str,
        amp_phase: bool,
        half_image: bool = True,
        **kwargs,
    ) -> None:
        """Initialize the PT writer.

        Parameters
        ----------
        output_path : str or Path
            Directory path where .pt files will be written.
        dataset_type : str
            Type of dataset being written (e.g., 'train', 'test', 'validation').
        amp_phase : bool
            If True, metadata key ``TYPE`` will contain 'amp_phase',
            otherwise 'real_imag'.
        half_image : bool, optional
            If ``True`` (default), crop images to half their height
            (plus overlap) before writing.
        """
        self.output_path = Path(output_path)
        self.dataset_type = dataset_type
        self.half_image = half_image
        self.data_type = "amp_phase" if amp_phase else "real_imag"

    def write(
        self,
        x: np.ndarray | torch.Tensor,
        y: np.ndarray | torch.Tensor,
        *,
        index,
        bundle_length: int,
        overlap: int = 5,
        name_x: str = "X",
        name_y: str = "y",
        **kwargs,
    ) -> None:
        """Write the samples of a bundle to individual PyTorch (.pt) files.

        The input arrays are validated and, if ``half_image`` is enabled,
        cropped to half their height (with ``overlap`` rows of overlap).
        Channel 0 and channel 1 are then combined into a complex tensor
        (``ch0 + 1j * ch1``). Each file holds a dict with the keys
        ``"SIM"`` (sample of ``x`` as a sparse tensor), ``"TRUTH"`` (sample
        of ``y`` as a dense tensor) and ``"TYPE"``.

        Parameters
        ----------
        x : np.ndarray or torch.Tensor
            First array of the FFT pair with shape (batch, 2, height, width).
        y : np.ndarray or torch.Tensor
            Second array of the FFT pair with shape (batch, 2, height, width).
        index : int
            Bundle index; file numbers are ``index * bundle_length + i``.
        bundle_length : int
            Number of samples to write from this bundle. Must not exceed
            the batch size of ``x`` and ``y``.
        overlap : int, optional
            Overlap parameter for extracting half-images. Default: 5.
        name_x : str, optional
            Currently unused; accepted for interface compatibility.
            Default: ``"X"``.
        name_y : str, optional
            Currently unused; accepted for interface compatibility.
            Default: ``"y"``.

        Raises
        ------
        ValueError
            If x or y does not have 4 dimensions with axis 1 of size 2, or
            if ``bundle_length`` exceeds the number of available samples.

        Examples
        --------
        >>> rng = np.random.default_rng()
        >>>
        >>> with PTWriter(
        ...     output_path="./data", dataset_type="train", amp_phase=True
        ... ) as writer:
        ...     x_data = rng.uniform(size=(5, 10, 2, 256, 256))
        ...     y_data = rng.uniform(size=(5, 10, 2, 256, 256))
        ...
        ...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
        ...         writer.write(x, y, index=bundle_id, bundle_length=len(x))
        """
        # Validate before cropping/writing: malformed input would otherwise
        # fail with an IndexError inside get_half_image, and a too-large
        # bundle_length would fail only after some files were written.
        self.test_shapes(x, "X")
        self.test_shapes(y, "y")
        available = min(x.shape[0], y.shape[0])
        if bundle_length > available:
            raise ValueError(
                f"bundle_length={bundle_length} exceeds the number of "
                f"samples in x/y ({available})!"
            )

        if self.half_image:
            x, y = self.get_half_image(x, y, overlap=overlap)

        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if isinstance(y, np.ndarray):
            y = torch.from_numpy(y)

        # Channel 0 becomes the real part, channel 1 the imaginary part.
        x = x[:, 0] + 1j * x[:, 1]
        y = y[:, 0] + 1j * y[:, 1]

        # File numbers are unique across bundles only if every bundle is
        # written with the same bundle_length.
        output_dir = Path(self.output_path)
        for i in range(bundle_length):
            output_file = (
                output_dir / f"samp_{self.dataset_type}_{index * bundle_length + i}.pt"
            )
            torch.save(
                obj={"SIM": x[i].to_sparse(), "TRUTH": y[i], "TYPE": self.data_type},
                f=output_file,
            )