WDSShardWriter

class pyvisgen.io.WDSShardWriter(output_path: str | Path, *, dataset_type: str, total_samples: int, shard_pattern: str, amp_phase: bool, compress: bool = False, **kwargs)

Bases: DataWriter

WebDataset file writer for pyvisgen datasets.

This writer saves data arrays to .tar(.gz) files using the WebDataset library, writing each bundle to its own .tar(.gz) file. Before writing, it automatically crops the images to half their height plus a small overlap and validates the array shapes.
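
The shape check and half-height crop can be sketched in plain NumPy. The helpers validate_pair and crop_half are hypothetical, and the sketch assumes the upper height // 2 + overlap rows are the ones kept; the actual slicing inside pyvisgen may differ.

```python
import numpy as np


def validate_pair(x: np.ndarray, y: np.ndarray) -> None:
    """Raise ValueError unless x and y are matching (batch, 2, height, width) arrays."""
    for name, arr in (("x", x), ("y", y)):
        if arr.ndim != 4 or arr.shape[1] != 2:
            raise ValueError(
                f"{name} must have shape (batch, 2, height, width), got {arr.shape}"
            )
    if x.shape != y.shape:
        raise ValueError(f"x and y shapes differ: {x.shape} vs {y.shape}")


def crop_half(arr: np.ndarray, overlap: int = 5) -> np.ndarray:
    """Keep the first height // 2 + overlap rows (assumption: upper half is kept)."""
    height = arr.shape[-2]
    return arr[..., : height // 2 + overlap, :]


rng = np.random.default_rng(0)
x = rng.uniform(size=(10, 2, 256, 256))
y = rng.uniform(size=(10, 2, 256, 256))

validate_pair(x, y)  # check the raw input shapes first
x_half, y_half = crop_half(x), crop_half(y)
print(x_half.shape)  # (10, 2, 133, 256)
```

With the default overlap of 5, a 256-row image keeps 256 // 2 + 5 = 133 rows.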

Parameters:
output_path : str or Path

Directory path where .tar files will be written.

dataset_type : str

Type of dataset being written (e.g., ‘train’, ‘test’, ‘validation’). This is used in the file names and shard patterns.

shard_pattern : str

Format string for naming shard files. It should include a format specifier for the shard index (e.g., “%06d.tar”). The write() method automatically adds dataset_type to the front of the shard name (e.g., “train-%06d.tar”).

amp_phase : bool

If True, saves “amp_phase” to the .parquet metadata files; if False, saves “real_imag” instead.

compress : bool, optional

If True, shards are gzip-compressed and ‘.gz’ is automatically appended to the shard pattern. Default is False.

**kwargs

Additional keyword arguments for compatibility with other writers.
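
The naming rules for shard_pattern, dataset_type and compress can be sketched as a small function. shard_name is a hypothetical helper, and the "<dataset_type>-" prefix is an assumption read off the parameter descriptions, not pyvisgen's actual code.

```python
def shard_name(
    shard_pattern: str, dataset_type: str, index: int, compress: bool = False
) -> str:
    """Build a shard file name as the parameter docs describe (format assumed)."""
    pattern = f"{dataset_type}-{shard_pattern}"  # assumed "<dataset_type>-" prefix
    if compress:
        pattern += ".gz"  # compress=True appends ".gz" to the pattern
    return pattern % index  # %-format the shard index into the pattern


print(shard_name("%06d.tar", "train", 3))  # train-000003.tar
print(shard_name("%06d.tar", "test", 12, compress=True))  # test-000012.tar.gz
```

Python's %-formatting raises ValueError for a pattern with an incomplete conversion such as "%06.tar", so the trailing d in "%06d" is required.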

Examples

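>>> import numpy as np
>>> from pyvisgen.io import WDSShardWriter
>>>
>>> rng = np.random.default_rng()
>>> x_data = rng.uniform(size=(10, 2, 256, 256))
>>> y_data = rng.uniform(size=(10, 2, 256, 256))
>>> writer = WDSShardWriter(
...     output_path="./data",
...     dataset_type="train",
...     total_samples=10,
...     shard_pattern="train-%06d.tar",
...     amp_phase=True,
... )
>>> writer.write(x_data, y_data, index=0)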

Or as a context manager:

>>> import numpy as np
>>> from pyvisgen.io import WDSShardWriter
>>>
>>> rng = np.random.default_rng()
>>> # 5 bundles of 10 samples, each sample of shape (2, 256, 256)
>>> x_data = rng.uniform(size=(5, 10, 2, 256, 256))
>>> y_data = rng.uniform(size=(5, 10, 2, 256, 256))
>>>
>>> with WDSShardWriter(
...     output_path="./data",
...     dataset_type="train",
...     total_samples=50,
...     shard_pattern="train-%06d.tar",
...     amp_phase=True,
... ) as writer:
...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
...         writer.write(x, y, index=bundle_id, overlap=5)

Methods Summary

write(x, y, index[, overlap])

Write data bundles to individual .tar(.gz) files.

Methods Documentation

write(x: ndarray, y: ndarray, index: int, overlap=5, **kwargs) → None

Write data bundles to individual .tar(.gz) files.

The input arrays are cropped to half their height plus an overlap of overlap pixels, validated, and then written as .npy files inside the .tar(.gz) archive.
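
To illustrate what storing .npy files inside a .tar archive involves, here is a sketch that uses the standard-library tarfile module rather than the WebDataset library the writer actually uses. The helpers add_npy and write_bundle and the member names (000000.x.npy, 000000.y.npy, ...) are hypothetical; they follow the WebDataset convention that files sharing a basename before the first dot form one sample.

```python
import io
import os
import tarfile
import tempfile

import numpy as np


def add_npy(tar: tarfile.TarFile, name: str, arr: np.ndarray) -> None:
    """Serialize arr with np.save into memory and add it to tar as member name."""
    buf = io.BytesIO()
    np.save(buf, arr)
    info = tarfile.TarInfo(name=name)
    info.size = buf.tell()  # number of bytes written by np.save
    buf.seek(0)
    tar.addfile(info, buf)


def write_bundle(path: str, x: np.ndarray, y: np.ndarray, compress: bool = False) -> None:
    """Write each sample of a bundle as <key>.x.npy and <key>.y.npy (hypothetical names)."""
    mode = "w:gz" if compress else "w"
    with tarfile.open(path, mode) as tar:
        for i, (xi, yi) in enumerate(zip(x, y)):
            key = f"{i:06d}"  # files sharing this key form one WebDataset sample
            add_npy(tar, f"{key}.x.npy", xi)
            add_npy(tar, f"{key}.y.npy", yi)


rng = np.random.default_rng(0)
x = rng.uniform(size=(3, 2, 133, 256))
y = rng.uniform(size=(3, 2, 133, 256))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train-000000.tar")
    write_bundle(path, x, y)
    with tarfile.open(path) as tar:  # mode "r" also reads gzip transparently
        names = tar.getnames()
        first = np.load(io.BytesIO(tar.extractfile("000000.x.npy").read()))

print(names[:2])  # ['000000.x.npy', '000000.y.npy']
```

Serializing into an in-memory io.BytesIO buffer avoids writing temporary .npy files to disk before archiving them.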

Parameters:
x : np.ndarray

First array of the FFT pair with shape (batch, 2, height, width). Expected to have 4 dimensions with axis 1 of size 2.

y : np.ndarray

Second array of the FFT pair with shape (batch, 2, height, width). Expected to have 4 dimensions with axis 1 of size 2.

index : int

Bundle index used in the output filename.

overlap : int, optional

Overlap, in pixels, that each extracted half-image extends beyond half the image height. Default is 5.
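
The size-2 axis 1 of x and y holds two channels. Assuming, as the class's amp_phase option suggests, that these are either amplitude and phase or real and imaginary parts of a complex array, building such an input could look like the sketch below; to_channels is a hypothetical helper, not part of pyvisgen.

```python
import numpy as np


def to_channels(z: np.ndarray, amp_phase: bool) -> np.ndarray:
    """Stack a complex (batch, height, width) array into (batch, 2, height, width)."""
    if amp_phase:
        parts = (np.abs(z), np.angle(z))  # channel 0: amplitude, channel 1: phase (rad)
    else:
        parts = (z.real, z.imag)  # channel 0: real part, channel 1: imaginary part
    return np.stack(parts, axis=1)


rng = np.random.default_rng(0)
images = rng.uniform(size=(4, 64, 64))
vis = np.fft.fft2(images)  # complex 2D FFT over the last two axes

x = to_channels(vis, amp_phase=True)
print(x.shape)  # (4, 2, 64, 64)
```

Either representation is lossless: amplitude * exp(1j * phase) and real + 1j * imag both recover the complex array.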
