WDSShardWriter

class pyvisgen.io.datawriters.WDSShardWriter(output_path: str | Path, *, dataset_type: str, total_samples: int, shard_pattern: str, amp_phase: bool, compress: bool = False, **kwargs)

Bases: DataWriter

WebDataset file writer for pyvisgen datasets.

This writer uses the WebDataset library to save data arrays to .tar or .tar.gz files, writing each bundle to a separate file. Before writing, it crops images to half their height plus a small overlap (5 pixels by default) and validates the array shapes.

Parameters:
output_path : str or Path

Directory path where .tar files will be written.

dataset_type : str

Type of dataset being written (e.g., ‘train’, ‘test’, ‘validation’). This is used in the file names and shard patterns.

shard_pattern : str

Format string for naming shard files. It should include a printf-style format specifier for the shard index (e.g., “%06d.tar”). The write() method automatically adds dataset_type to the shard name (e.g., “train-%06d.tar”).

amp_phase : bool

If True, saves “amp_phase” to the .parquet metadata files; if False, saves “real_imag” instead.

compress : bool, optional

If True, compresses shards with gzip and automatically appends ‘.gz’ to the shard pattern. Default is False.

**kwargs

Additional keyword arguments for compatibility with other writers.
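The shard_pattern and compress descriptions above imply a naming rule for shard files. The following is a minimal sketch of that rule, assuming that the pattern is given without the dataset type (e.g., “%06d.tar”), that dataset_type is prepended with a hyphen, and that compress=True appends “.gz”. The helper shard_filename is hypothetical and not part of pyvisgen.

```python
def shard_filename(
    shard_pattern: str, dataset_type: str, index: int, compress: bool = False
) -> str:
    """Build a shard file name (hypothetical helper; the naming rule is assumed).

    Assumes shard_pattern holds a printf-style index specifier such as
    "%06d.tar", that dataset_type is prepended with a hyphen, and that
    gzip compression adds a ".gz" suffix.
    """
    pattern = f"{dataset_type}-{shard_pattern}"
    if compress:
        pattern += ".gz"
    # Substitute the shard index into the printf-style specifier.
    return pattern % index


print(shard_filename("%06d.tar", "train", 3))                 # train-000003.tar
print(shard_filename("%06d.tar", "test", 42, compress=True))  # test-000042.tar.gz
```

The class examples pass a pattern that already starts with “train-”, so check the installed pyvisgen version for the exact rule before relying on the prefix step.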

Examples

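>>> import numpy as np
>>> from pyvisgen.io.datawriters import WDSShardWriter
>>> rng = np.random.default_rng()
>>> x_data = rng.uniform(size=(10, 2, 256, 256))
>>> y_data = rng.uniform(size=(10, 2, 256, 256))
>>> writer = WDSShardWriter(
...     output_path="./data",
...     dataset_type="train",
...     total_samples=10,
...     shard_pattern="train-%06d.tar",
...     amp_phase=True,
... )
>>> writer.write(x_data, y_data, index=0)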

Or as a context manager:

>>> rng = np.random.default_rng()
>>> x_data = rng.uniform(size=(5, 10, 2, 256, 256))  # 5 bundles of 10 samples
>>> y_data = rng.uniform(size=(5, 10, 2, 256, 256))
>>> with WDSShardWriter(
...     output_path="./data",
...     dataset_type="train",
...     total_samples=50,
...     shard_pattern="train-%06d.tar",
...     amp_phase=True,
... ) as writer:
...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
...         writer.write(x, y, index=bundle_id, overlap=5)

Methods Summary

get_half_image(x, y[, overlap])

Crop every image to half its height plus a small overlap.

test_shapes(array, name)

Validate the shape of input arrays.

write(x, y, index[, overlap])

Write data bundles to individual .tar(.gz) files.

Methods Documentation

get_half_image(x: ndarray, y: ndarray, overlap: int = 5) → tuple[ndarray]

Crop every image to half its height plus a small overlap.

Parameters:
x : np.ndarray

Simulated data array with shape (B, C, H, W).

y : np.ndarray

Ground truth array with shape (B, C, H, W).

overlap : int, optional

Overlap, in pixels, added to the extracted half-images. Default: 5.

Returns:
tuple[np.ndarray, np.ndarray]

Tuple containing the cropped x and y arrays.
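A minimal sketch of the cropping described above, assuming the method keeps the first H // 2 + overlap rows along the height axis (axis 2). Which half is kept is an assumption; half_image is a hypothetical stand-in, not the pyvisgen implementation.

```python
import numpy as np


def half_image(
    x: np.ndarray, y: np.ndarray, overlap: int = 5
) -> tuple[np.ndarray, np.ndarray]:
    """Crop (B, C, H, W) arrays to the first H // 2 + overlap rows.

    Hypothetical stand-in for WDSShardWriter.get_half_image; keeping the
    top rows (rather than the bottom ones) is an assumption.
    """
    rows = x.shape[2] // 2 + overlap
    # Basic slicing returns views; call .copy() if the crop must own its data.
    return x[:, :, :rows, :], y[:, :, :rows, :]


rng = np.random.default_rng(0)
x = rng.uniform(size=(10, 2, 256, 256))
y = rng.uniform(size=(10, 2, 256, 256))
x_half, y_half = half_image(x, y, overlap=5)
print(x_half.shape)  # (10, 2, 133, 256)
```

With 256 × 256 images and the default overlap, each cropped image keeps 128 + 5 = 133 rows.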

test_shapes(array: ndarray, name: str) → None

Validate the shape of input arrays.

Arrays must have the shape (B, C, H, W), where B is the batch size, C the number of channels (2), and H and W the height and width of the images.
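Axis 1 holds two real-valued channels. As an illustration only, the sketch below assumes these are the amplitude and phase (matching amp_phase=True in the constructor) or the real and imaginary parts (amp_phase=False) of a complex (B, H, W) array. The helper to_two_channels and the channel order are hypothetical, not taken from pyvisgen.

```python
import numpy as np


def to_two_channels(z: np.ndarray, amp_phase: bool) -> np.ndarray:
    """Split a complex (B, H, W) array into a real (B, 2, H, W) array.

    Hypothetical helper: the channel order (amplitude, phase) or
    (real, imaginary) is an assumption.
    """
    if amp_phase:
        channels = (np.abs(z), np.angle(z))
    else:
        channels = (z.real, z.imag)
    # Stack the two components along a new channel axis at position 1.
    return np.stack(channels, axis=1)


z = np.array([[[1 + 1j, -2 + 0j]]])  # shape (B=1, H=1, W=2)
print(to_two_channels(z, amp_phase=False).shape)  # (1, 2, 1, 2)
```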

Parameters:
array : np.ndarray

Array to validate.

name : str

Name of the array for error reporting.

Raises:
ValueError

If axis 1 of the array does not have size 2.

ValueError

If array does not have exactly 4 dimensions.
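A minimal sketch of the two checks listed under Raises. The order of the checks and the message wording are assumptions; check_shape is a hypothetical stand-in for test_shapes.

```python
import numpy as np


def check_shape(array: np.ndarray, name: str) -> None:
    """Raise ValueError unless array has shape (B, 2, H, W).

    Hypothetical stand-in for WDSShardWriter.test_shapes; the check order
    and message wording are assumptions.
    """
    # Check the dimension count first so that array.shape[1] is safe to read.
    if array.ndim != 4:
        raise ValueError(
            f"{name} must have 4 dimensions (B, C, H, W), got {array.ndim}"
        )
    if array.shape[1] != 2:
        raise ValueError(f"{name} must have size 2 on axis 1, got {array.shape[1]}")


check_shape(np.zeros((10, 2, 133, 256)), "x")  # passes silently
try:
    check_shape(np.zeros((10, 3, 133, 256)), "y")
except ValueError as err:
    print(err)  # y must have size 2 on axis 1, got 3
```

Testing the number of dimensions first avoids an IndexError when a 1-D array is passed, since such an array has no axis 1.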

write(x: ndarray, y: ndarray, index: int, overlap=5, **kwargs) → None

Write data bundles to individual .tar(.gz) files.

The input arrays are cropped to half their height, extended by overlap pixels, and validated before being written as .npy files inside the .tar(.gz) archives.
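WebDataset shards are ordinary tar archives in which members sharing the same basename prefix (the sample key) form one sample. The sketch below writes such a shard with the standard-library tarfile module instead of the WebDataset library. The key format and the member names “<key>.x.npy” and “<key>.y.npy” are assumptions, not the names pyvisgen uses; write_npy_shard is hypothetical.

```python
import io
import tarfile
import tempfile
from pathlib import Path

import numpy as np


def write_npy_shard(
    path: Path, x: np.ndarray, y: np.ndarray, compress: bool = False
) -> None:
    """Store each pair (x[i], y[i]) as two .npy members sharing one sample key.

    Hypothetical sketch using tarfile; the member names "<key>.x.npy" and
    "<key>.y.npy" are assumptions, not pyvisgen's.
    """
    mode = "w:gz" if compress else "w"
    with tarfile.open(path, mode) as tar:
        for i, (xi, yi) in enumerate(zip(x, y)):
            key = f"sample{i:06d}"  # WebDataset groups members by this prefix
            for suffix, arr in (("x.npy", xi), ("y.npy", yi)):
                buf = io.BytesIO()
                np.save(buf, arr)  # serialize the array in .npy format
                data = buf.getvalue()
                info = tarfile.TarInfo(name=f"{key}.{suffix}")
                info.size = len(data)
                tar.addfile(info, fileobj=io.BytesIO(data))


rng = np.random.default_rng(0)
x = rng.uniform(size=(3, 2, 133, 256))
y = rng.uniform(size=(3, 2, 133, 256))
with tempfile.TemporaryDirectory() as tmp:
    shard = Path(tmp) / "train-000000.tar.gz"
    write_npy_shard(shard, x, y, compress=True)
    with tarfile.open(shard, "r:gz") as tar:
        print(tar.getnames()[:2])  # ['sample000000.x.npy', 'sample000000.y.npy']
```

Serializing into an in-memory buffer first lets the TarInfo header carry the exact byte size without writing temporary .npy files to disk.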

Parameters:
x : np.ndarray

First array of the FFT pair with shape (batch, 2, height, width). Expected to have 4 dimensions with axis 1 of size 2.

y : np.ndarray

Second array of the FFT pair with shape (batch, 2, height, width). Expected to have 4 dimensions with axis 1 of size 2.

index : int

Bundle index used in the output filename.

overlap : int, optional

Overlap, in pixels, added to the extracted half-images. Default: 5.

Examples

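>>> import numpy as np
>>> from pyvisgen.io.datawriters import WDSShardWriter
>>> rng = np.random.default_rng()
>>> x_data = rng.uniform(size=(10, 2, 256, 256))
>>> y_data = rng.uniform(size=(10, 2, 256, 256))
>>> writer = WDSShardWriter(
...     output_path="./data",
...     dataset_type="train",
...     total_samples=10,
...     shard_pattern="train-%06d.tar",
...     amp_phase=True,
... )
>>> writer.write(x_data, y_data, index=0)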

Or as a context manager:

>>> rng = np.random.default_rng()
>>> x_data = rng.uniform(size=(5, 10, 2, 256, 256))  # 5 bundles of 10 samples
>>> y_data = rng.uniform(size=(5, 10, 2, 256, 256))
>>> with WDSShardWriter(
...     output_path="./data",
...     dataset_type="train",
...     total_samples=50,
...     shard_pattern="train-%06d.tar",
...     amp_phase=True,
... ) as writer:
...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
...         writer.write(x, y, index=bundle_id, overlap=5)