WDSShardWriter
- class pyvisgen.io.datawriters.WDSShardWriter(output_path: str | Path, *, dataset_type: str, total_samples: int, shard_pattern: str, amp_phase: bool, compress: bool = False, **kwargs)
Bases: DataWriter
WebDataset file writer for pyvisgen datasets.
This writer saves data arrays to .tar or .tar.gz files using the WebDataset library, writing each bundle to a separate file. Before writing, it crops the images to half their height (keeping a small overlap) and validates the array shapes.
- Parameters:
- output_path : str or Path
Directory path where the .tar files will be written.
- dataset_type : str
Type of dataset being written (e.g., ‘train’, ‘test’, ‘validation’). This is used in the file names and shard patterns.
- shard_pattern : str
Format string for naming shard files. It should include a format specifier for the shard index (e.g., “%06d.tar”). The write() method automatically adds dataset_type to the shard name (e.g., “train-%06d.tar”).
- amp_phase : bool
If True, saves “amp_phase” to the .parquet metadata files; if False, saves “real_imag” instead.
- compress : bool, optional
If True, compresses shards using gzip. Default is False. When enabled, ‘.gz’ is appended to the shard pattern automatically.
- **kwargs
Additional keyword arguments for compatibility with other writers.
Examples
>>> from pyvisgen.io.datawriters import WDSShardWriter
>>> writer = WDSShardWriter(
...     output_path="./data",
...     dataset_type="train",
...     total_samples=total_samples,
...     shard_pattern="train-%06d.tar",
...     amp_phase=True,
... )
>>> writer.write(x_data, y_data, index=0)
Or as a context manager:
>>> import numpy as np
>>> from pyvisgen.io.datawriters import WDSShardWriter
>>> rng = np.random.default_rng()
>>> with WDSShardWriter(
...     output_path="./data",
...     dataset_type="train",
...     total_samples=total_samples,
...     shard_pattern="train-%06d.tar",
...     amp_phase=True,
... ) as writer:
...     x_data = rng.uniform(size=(5, 10, 2, 256, 256))
...     y_data = rng.uniform(size=(5, 10, 2, 256, 256))
...
...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
...         writer.write(x, y, index=bundle_id, overlap=5)
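The shard naming rules described under shard_pattern, dataset_type, and compress, plus the metadata label chosen by amp_phase, can be sketched as follows. This is a minimal illustration, not pyvisgen's code: the helpers shard_name and data_format_label are hypothetical, and joining dataset_type to the pattern with a "-" separator is an assumption based on the “train-%06d.tar” example.

```python
def shard_name(
    shard_pattern: str, dataset_type: str, index: int, compress: bool = False
) -> str:
    """Hypothetical helper: build the file name of one shard."""
    # Assumption: dataset_type is prefixed with a "-" separator.
    pattern = f"{dataset_type}-{shard_pattern}"
    # Assumption: compression only appends ".gz" to the pattern.
    if compress:
        pattern += ".gz"
    # Fill in the shard index with printf-style formatting, e.g. "%06d".
    return pattern % index


def data_format_label(amp_phase: bool) -> str:
    """Hypothetical helper: label written to the .parquet metadata."""
    return "amp_phase" if amp_phase else "real_imag"


print(shard_name("%06d.tar", "train", 3))  # train-000003.tar
print(shard_name("%06d.tar", "test", 42, compress=True))  # test-000042.tar.gz
print(data_format_label(False))  # real_imag
```

In this sketch the index is inserted with % formatting, so the pattern needs a complete conversion such as %06d; a pattern like "%06.tar" (missing the d) raises ValueError.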
Methods Summary
- get_half_image(x, y[, overlap]): Extract half the height of every image with a small overlap.
- test_shapes(array, name): Validate the shape of input arrays.
- write(x, y, index[, overlap]): Write data bundles to individual .tar(.gz) files.
Methods Documentation
- get_half_image(x: ndarray, y: ndarray, overlap: int = 5) -> tuple[ndarray]
Extract half the height of every image with a small overlap.
- Parameters:
- x : np.ndarray
Simulated data array with shape (B, C, H, W).
- y : np.ndarray
Ground truth array with shape (B, C, H, W).
- overlap : int, optional
Overlap, in pixels, for the extracted half-images. Default: 5.
- Returns:
- tuple[np.ndarray, np.ndarray]
Tuple containing the cropped x and y arrays.
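A minimal NumPy sketch of the half-height crop follows. Which rows pyvisgen actually keeps is not stated in this documentation; the hypothetical half_image helper assumes the upper H // 2 + overlap rows along the height axis, so a 256-pixel-high image with overlap=5 becomes 133 rows high.

```python
import numpy as np


def half_image(
    x: np.ndarray, y: np.ndarray, overlap: int = 5
) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical half-height crop of two (B, C, H, W) arrays.

    Assumption: the upper rows 0 .. H // 2 + overlap are kept.
    """
    height = x.shape[-2]
    stop = height // 2 + overlap
    # Slice only the height axis (axis 2); batch, channel and width
    # axes are left untouched.
    return x[:, :, :stop, :], y[:, :, :stop, :]


rng = np.random.default_rng(0)
x = rng.uniform(size=(10, 2, 256, 256))
y = rng.uniform(size=(10, 2, 256, 256))
x_half, y_half = half_image(x, y, overlap=5)
print(x_half.shape)  # (10, 2, 133, 256)
```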
- test_shapes(array: ndarray, name: str) -> None
Validate the shape of input arrays.
Arrays should have the shape (B, C, H, W), where B is the batch size, C the number of channels (2), and H and W the height and width of the images.
- Parameters:
- array : np.ndarray
Array to validate.
- name : str
Name of the array, used in error messages.
- Raises:
- ValueError
If array axis 1 is not of size 2.
- ValueError
If array does not have exactly 4 dimensions.
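The two checks listed under Raises can be sketched as below. check_shapes is a hypothetical stand-in for test_shapes and its error messages are invented. The sketch tests the number of dimensions before axis 1, so that an array with fewer than two dimensions still raises ValueError rather than IndexError.

```python
import numpy as np


def check_shapes(array: np.ndarray, name: str) -> None:
    """Hypothetical stand-in for test_shapes: require a (B, 2, H, W) array."""
    # Check ndim first: reading shape[1] of a 1-D array would raise IndexError.
    if array.ndim != 4:
        raise ValueError(
            f"{name} must have 4 dimensions (B, C, H, W), got {array.ndim}."
        )
    # Axis 1 holds the two channels.
    if array.shape[1] != 2:
        raise ValueError(
            f"{name} must have 2 channels on axis 1, got {array.shape[1]}."
        )


check_shapes(np.zeros((10, 2, 64, 64)), "x")  # passes silently
for bad in (np.zeros((2, 64, 64)), np.zeros((10, 3, 64, 64))):
    try:
        check_shapes(bad, "x")
    except ValueError as err:
        print(err)
```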
- write(x: ndarray, y: ndarray, index: int, overlap=5, **kwargs) -> None
Write data bundles to individual .tar(.gz) files.
The input arrays are cropped to half their height (with an overlap of overlap pixels) and validated before being written as .npy files inside the .tar archives.
- Parameters:
- x : np.ndarray
First array of the FFT pair with shape (batch, 2, height, width). Expected to have 4 dimensions with axis 1 of size 2.
- y : np.ndarray
Second array of the FFT pair with shape (batch, 2, height, width). Expected to have 4 dimensions with axis 1 of size 2.
- index : int
Bundle index used in the output filename.
- overlap : int, optional
Overlap, in pixels, for the extracted half-images. Default: 5.
Examples
>>> from pyvisgen.io.datawriters import WDSShardWriter
>>> writer = WDSShardWriter(
...     output_path="./data",
...     dataset_type="train",
...     total_samples=total_samples,
...     shard_pattern="train-%06d.tar",
...     amp_phase=True,
... )
>>> writer.write(x_data, y_data, index=0)
Or as a context manager:
>>> import numpy as np
>>> from pyvisgen.io.datawriters import WDSShardWriter
>>> rng = np.random.default_rng()
>>> with WDSShardWriter(
...     output_path="./data",
...     dataset_type="train",
...     total_samples=total_samples,
...     shard_pattern="train-%06d.tar",
...     amp_phase=True,
... ) as writer:
...     x_data = rng.uniform(size=(5, 10, 2, 256, 256))
...     y_data = rng.uniform(size=(5, 10, 2, 256, 256))
...
...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
...         writer.write(x, y, index=bundle_id, overlap=5)
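The final step of write(), storing a bundle as .npy files inside a .tar or .tar.gz archive, can be sketched with the standard-library tarfile module in place of the WebDataset library. This is an illustration under stated assumptions, not pyvisgen's implementation: the write_bundle helper is hypothetical, it skips the cropping and shape checks, and the member names (sample000000.x.npy, sample000000.y.npy, ...) assume the WebDataset convention that files sharing a basename form one sample.

```python
import io
import tarfile
import tempfile
from pathlib import Path

import numpy as np


def write_bundle(
    x: np.ndarray, y: np.ndarray, path: Path, compress: bool = False
) -> None:
    """Hypothetical sketch: store every sample of a bundle as .npy members."""
    # "w:gz" writes a gzip-compressed archive, "w" an uncompressed one.
    mode = "w:gz" if compress else "w"
    with tarfile.open(path, mode) as tar:
        for i, (x_i, y_i) in enumerate(zip(x, y)):
            for field, arr in (("x", x_i), ("y", y_i)):
                buf = io.BytesIO()
                np.save(buf, arr)  # serialize the array in .npy format
                data = buf.getvalue()
                # Assumed WebDataset-style member name: <key>.<field>.npy
                info = tarfile.TarInfo(name=f"sample{i:06d}.{field}.npy")
                info.size = len(data)
                tar.addfile(info, io.BytesIO(data))


# Usage: write a bundle of 3 samples and read the first x array back.
rng = np.random.default_rng(0)
x = rng.uniform(size=(3, 2, 16, 16))
y = rng.uniform(size=(3, 2, 16, 16))

with tempfile.TemporaryDirectory() as tmp:
    shard = Path(tmp) / "train-000000.tar.gz"
    write_bundle(x, y, shard, compress=True)
    # The default read mode detects gzip compression automatically.
    with tarfile.open(shard) as tar:
        names = tar.getnames()
        restored = np.load(io.BytesIO(tar.extractfile(names[0]).read()))

print(names[:2])  # ['sample000000.x.npy', 'sample000000.y.npy']
```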