import re
from pathlib import Path
from typing import Self
import h5py
import numpy as np
import torch
from natsort import natsorted
from rich.progress import track
from .datawriters import H5Writer, PTWriter, WDSShardWriter
try:
import pyarrow as pa
import webdataset as wds
_WDS_AVAIL = True
except ImportError:
_WDS_AVAIL = False
__all__ = ["DataConverter"]
def _batch_array(
array: np.ndarray, batch_size: int, return_indices: bool = False
) -> list[np.ndarray]:
"""Splits array into batches of given batch size. Depending on the batch
size, the last array may contain the remainder of elements and may
be smaller batch_size.
Parameters
----------
array : np.ndarray
Array to be batched.
batch_size : int
Batch size for the splits.
return_indices : bool, optional
If ``True``, return indices of splits. Default: ``False``
Returns
-------
list
List of batched arrays.
indices
Indices of splits if return_indices is ``True``.
"""
indices = np.arange(batch_size, len(array), batch_size)
if return_indices:
# also include zero when returning indices
return np.split(array, indices), np.insert(indices, 0, 0)
return np.split(array, indices)
[docs]
class DataConverter:
"""Convert datasets between HDF5, WebDataset, and PyTorch formats.
This class allows loading datasets from various formats
and convert them to a target format. Where available or required,
metadata is read or added to the respective datasets.
Examples
--------
Convert WebDataset to HDF5:
>>> converter = DataConverter.from_wds("./data/visibilities")
>>> converter.to("./data/output", output_format="h5")
Convert HDF5 train split to WebDataset:
>>> converter = DataConverter.from_h5("./data/visibilities", dataset_split="train")
>>> converter.to("~/data/output", output_format="wds", compress=True)
"""
[docs]
@classmethod
def from_wds(cls, data_dir, dataset_split="all") -> Self:
"""Create a DataConverter instance from WebDataset files.
Parameters
----------
data_dir : str or :class:`~pathlib.Path`
Directory containing WebDataset .tar(.gz) files.
dataset_split : str or list
Dataset split to load. If "all", loads train, valid, and test.
Default: ``"all"``
Returns
-------
DataConverter
Configured DataConverter instance with WebDataset source files.
Raises
------
ImportError
If webdataset package is not installed.
"""
if not _WDS_AVAIL:
raise ImportError(
"Could not import webdataset. Please make sure you install "
"pyvisgen with the webdataset extra: "
"uv pip install pyvisgen[webdataset]"
)
cls = cls()
cls._FMT = "wds"
data_dir = Path(data_dir).expanduser().resolve()
dataset_split = cls._get_dataset_split(dataset_split)
cls.datasets = {t: data_dir.glob(f"{t}-*.tar*") for t in dataset_split}
return cls
[docs]
@classmethod
def from_h5(cls, data_dir, dataset_split="all") -> Self:
"""Create a DataConverter instance from HDF5 files.
Parameters
----------
data_dir : str or :class:`~pathlib.Path`
Directory containing HDF5 files.
dataset_split : str or list
Dataset split to load. If "all", loads train, valid, and test.
Default: ``"all"``
Returns
-------
DataConverter
Configured DataConverter instance with HDF5 source files.
"""
cls = cls()
cls._FMT = "h5"
data_dir = Path(data_dir).expanduser().resolve()
dataset_split = cls._get_dataset_split(dataset_split)
cls.datasets = {t: data_dir.glob(f"*{t}_*.h5") for t in dataset_split}
return cls
[docs]
@classmethod
def from_pt(cls, data_dir, dataset_split="all"):
"""Create a DataConverter instance from HDF5 files.
Parameters
----------
data_dir : str or :class:`~pathlib.Path`
Directory containing .pt files.
dataset_split : str or list
Dataset split to load. If "all", loads train, valid, and test.
Default: ``"all"``
Returns
-------
DataConverter
Configured DataConverter instance with PyTorch pickle source files.
"""
cls = cls()
cls._FMT = "pt"
data_dir = Path(data_dir).expanduser().resolve()
dataset_split = cls._get_dataset_split(dataset_split)
cls.datasets = {t: data_dir.glob(f"*{t}_*.pt") for t in dataset_split}
return cls
def _get_dataset_split(self, dataset_split):
if not isinstance(dataset_split, list):
dataset_split = [dataset_split]
if "all" in dataset_split:
dataset_split = ["train", "valid", "test"]
return dataset_split
def _to_h5(self) -> None:
"""Internal method to handle conversion to HDF5 files."""
if self._FMT == "wds":
for dataset_type, files in track(
self.datasets.items(), description="Converting Dataset to HDF5"
):
with H5Writer(
output_path=self.output_dir,
dataset_type=dataset_type,
half_image=False,
) as writer:
self._handle_wds(files, writer)
elif self._FMT == "pt":
for dataset_type, files in track(
self.datasets.items(), description="Converting Dataset to HDF5"
):
with H5Writer(
output_path=self.output_dir,
dataset_type=dataset_type,
half_image=False,
) as writer:
self._handle_pt(files, writer)
elif self._FMT == "h5":
for dataset_type, files in track(
self.datasets.items(), description="Converting Dataset to HDF5"
):
with H5Writer(
output_path=self.output_dir,
dataset_type=dataset_type,
half_image=False,
) as writer:
self._handle_h5(files, writer)
def _to_wds(self) -> None:
"""Internal method to handle conversion to WebDataset files."""
if self._FMT == "h5":
for dataset_type, files in track(
self.datasets.items(), description="Converting Dataset to WDS"
):
# total_samples is updated after writing all files
total_samples = 0
with WDSShardWriter(
output_path=self.output_dir,
dataset_type=dataset_type,
total_samples=total_samples,
amp_phase=not self.amp_phase
if self.convert_representation
else self.amp_phase,
shard_pattern=self.shard_pattern,
compress=self.compress,
half_image=False,
) as writer:
total_samples = self._handle_h5(files, writer, total_samples)
for file in self.output_dir.glob(f"{dataset_type}*.parquet"):
metadata = pa.parquet.read_table(file).to_pandas()
metadata["total_samples_in_dataset"] = total_samples
table = pa.Table.from_pandas(metadata)
pa.parquet.write_table(table, file)
elif self._FMT == "pt":
# total_samples is updated after writing all files
total_samples = 0
for dataset_type, files in track(
self.datasets.items(), description="Converting Dataset to WDS"
):
with WDSShardWriter(
output_path=self.output_dir,
dataset_type=dataset_type,
total_samples=total_samples,
amp_phase=not self.amp_phase
if self.convert_representation
else self.amp_phase,
shard_pattern=self.shard_pattern,
compress=self.compress,
half_image=False,
) as writer:
total_samples = self._handle_pt(files, writer, total_samples)
for file in self.output_dir.glob(f"{dataset_type}*.parquet"):
metadata = pa.parquet.read_table(file).to_pandas()
metadata["total_samples_in_dataset"] = [total_samples]
table = pa.Table.from_pandas(metadata)
pa.parquet.write_table(table, file)
elif self._FMT == "wds":
# total_samples is updated after writing all files
total_samples = 0
for dataset_type, files in track(
self.datasets.items(), description="Converting Dataset to WDS"
):
with WDSShardWriter(
output_path=self.output_dir,
dataset_type=dataset_type,
total_samples=total_samples,
amp_phase=not self.amp_phase
if self.convert_representation
else self.amp_phase,
shard_pattern=self.shard_pattern,
compress=self.compress,
half_image=False,
) as writer:
total_samples = self._handle_wds(files, writer, total_samples)
for file in self.output_dir.glob(f"{dataset_type}*.parquet"):
metadata = pa.parquet.read_table(file).to_pandas()
metadata["total_samples_in_dataset"] = [total_samples]
table = pa.Table.from_pandas(metadata)
pa.parquet.write_table(table, file)
def _to_pt(self):
"""Internal method to handle conversion to PT files."""
if self._FMT == "wds":
for dataset_type, files in track(
self.datasets.items(), description="Converting Dataset to PT"
):
with PTWriter(
output_path=self.output_dir,
dataset_type=dataset_type,
amp_phase=not self.amp_phase
if self.convert_representation
else self.amp_phase,
half_image=False,
) as writer:
self._handle_wds(files, writer)
elif self._FMT == "h5":
for dataset_type, files in track(
self.datasets.items(), description="Converting Dataset to PT"
):
with PTWriter(
output_path=self.output_dir,
dataset_type=dataset_type,
amp_phase=not self.amp_phase
if self.convert_representation
else self.amp_phase,
half_image=False,
) as writer:
self._handle_h5(files, writer)
elif self._FMT == "pt":
for dataset_type, files in track(
self.datasets.items(), description="Converting Dataset to PT"
):
with PTWriter(
output_path=self.output_dir,
dataset_type=dataset_type,
amp_phase=not self.amp_phase
if self.convert_representation
else self.amp_phase,
half_image=False,
) as writer:
self._handle_pt(files, writer)
def _handle_wds(self, files, writer, total_samples=0):
for file in track(list(files), description="Processing files..."):
file_idx = re.findall(r"\d+", file.stem)
file_idx = re.sub(r"0+(.+)", r"\1", *file_idx)
webdataset = (
wds.WebDataset(str(file), shardshuffle=False)
.decode()
.to_tuple("input.npy", "target.npy")
)
x = []
y = []
for inp, tar in webdataset:
x.append(inp)
y.append(tar)
x = np.asarray(x)
y = np.asarray(y)
if self.convert_representation:
x = self.convert_repr.convert(torch.from_numpy(x))
y = self.convert_repr.convert(torch.from_numpy(y))
writer.write(x, y, index=int(file_idx), bundle_length=len(x))
total_samples += len(x)
return total_samples
def _handle_h5(self, files, writer, total_samples=0):
for file in track(list(files), description="Processing files..."):
data = h5py.File(file)
file_idx = re.findall(r"\d+", file.stem)
x = np.asarray(data["x"])
y = np.asarray(data["y"])
if self.convert_representation:
x = self.convert_repr.convert(torch.from_numpy(x))
y = self.convert_repr.convert(torch.from_numpy(y))
writer.write(
x,
y,
index=int(file_idx[0]),
bundle_length=len(x),
)
total_samples += len(x)
return total_samples
def _handle_pt(self, files, writer, total_samples=0):
bundles, indices = _batch_array(
np.asarray(natsorted(files)),
self.bundle_size,
return_indices=True,
)
for bundle, index in track(
zip(bundles, indices), description="Processing files..."
):
x = []
y = []
for file in bundle:
data = torch.load(file)
x.append(data["SIM"].to_dense())
y.append(data["TRUTH"])
x = np.asarray(x)
y = np.asarray(y)
x = np.stack((x.real, x.imag), axis=1)
y = np.stack((y.real, y.imag), axis=1)
if self.convert_representation:
x = self.convert_repr.convert(torch.from_numpy(x))
y = self.convert_repr.convert(torch.from_numpy(y))
writer.write(x, y, index=int(index), bundle_length=len(x))
total_samples += len(x)
return total_samples
[docs]
def to(
self,
output_dir: str | Path,
output_format: str = "h5",
amp_phase: bool = True,
shard_pattern: str = "%06d.tar",
compress: bool = True,
bundle_size: int = 100,
convert_representation: bool = False,
) -> None:
"""Convert the loaded dataset to the specified output format.
Parameters
----------
output_dir : str or :class:`~pathlib.Path`
Directory to write converted files to.
output_format : str, optional
Target format for conversion. One of h5, wds or pt.
Default: ``"h5"``
amp_phase : bool, optional
Whether data is in amplitude/phase or real/imaginary
representation. Default: ``True``
shard_pattern : str, optional
Naming pattern for WebDataset shards (only applies to wds output).
Default: ``"%06d.tar"``
compress : bool
Whether to compress WebDataset shards (only applies to wds output).
Default: ``True``
bundle_size : int, optional
Bundle size for HDF5 and WebDataset shards when converting from
PyTorch pickle files. Default: 100
convert_representation : bool, optional
If ``True`` convert from one amplitude/phase representation to
real/imaginary or vice versa. Note, that this requires amp_phase
to match the actual representation in the input data as this
determines which way the conversion will be applied.
Default: False
Raises
------
RuntimeError
If source and target formats are identical.
"""
self.convert_representation = convert_representation
if (
self._FMT.lower() == output_format.lower()
and not self.convert_representation
):
raise RuntimeError(
f"Forbidden: Cannot convert {self._FMT} to h5 if "
"'convert_representation' is set to 'False'. "
"Please make sure that input and output formats are different."
)
self.output_dir = Path(output_dir).expanduser().resolve()
if not self.output_dir.is_dir():
self.output_dir.mkdir(parents=True)
self.amp_phase = amp_phase
self.shard_pattern = shard_pattern
self.compress = compress
self.bundle_size = bundle_size
if self.convert_representation:
if amp_phase is None:
raise ValueError(
"Cannot convert data representation without "
"a valid value for 'amp_phase'. Please set 'amp_phase' "
"to either 'True' or 'False'."
)
self.convert_repr = DataTypeConverter(input_amp_phase=self.amp_phase)
match output_format:
case "h5":
self._to_h5()
case "wds":
self._to_wds()
case "pt":
self._to_pt()
class DataTypeConverter:
    """Switch two-component tensors between real/imaginary and amplitude/phase.

    The two components are always taken from index 0 and index 1 along
    dimension 1 of the input tensor, and the result is stacked along
    dimension 1 again.

    Parameters
    ----------
    input_amp_phase : bool, optional
        If ``True``, :meth:`convert` treats its input as amplitude/phase
        and returns real/imaginary; otherwise the opposite direction is
        applied. Default: ``True``
    """

    def __init__(self, input_amp_phase=True) -> None:
        self.input_amp_phase = input_amp_phase

    def to_amp_phase(self, data: torch.Tensor) -> torch.Tensor:
        """Return amplitude and phase computed from real and imaginary parts."""
        re_part = data[:, 0]
        im_part = data[:, 1]

        components = [
            torch.hypot(re_part, im_part),
            torch.atan2(im_part, re_part),
        ]
        return torch.stack(components, dim=1)

    def to_real_imag(self, data: torch.Tensor) -> torch.Tensor:
        """Return real and imaginary parts computed from amplitude and phase."""
        magnitude = data[:, 0]
        angle = data[:, 1]

        components = [
            magnitude * torch.cos(angle),
            magnitude * torch.sin(angle),
        ]
        return torch.stack(components, dim=1)

    def convert(self, data: torch.Tensor) -> torch.Tensor:
        """Convert ``data`` to the representation opposite to the input one."""
        transform = self.to_real_imag if self.input_amp_phase else self.to_amp_phase
        return transform(data)