PTWriter

class pyvisgen.io.PTWriter(output_path: Path, dataset_type: str, amp_phase: bool, half_image: bool = True, **kwargs)

Bases: DataWriter

DataWriter class for saving data in PyTorch (.pt) format.

Creates a new .pt file for each sample, named with the pattern samp_{dataset_type}_{index}.pt. The input arrays are cropped to half their height (extended by overlap pixels) and validated before writing.
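A minimal NumPy sketch of the two mechanics described above: the output filename and the half-height crop. The helper names output_file and crop_half are hypothetical, not part of pyvisgen. The crop assumes the kept half is the first half of the height axis plus overlap extra rows; the actual PTWriter may keep a different half or handle the overlap differently. The signature's half_image flag (default True) suggests the crop can be disabled, whereas this sketch always crops.

```python
from pathlib import Path

import numpy as np


def output_file(output_path: Path, dataset_type: str, index: int) -> Path:
    # Hypothetical helper reproducing the documented naming pattern
    # samp_{dataset_type}_{index}.pt inside the output directory.
    return Path(output_path) / f"samp_{dataset_type}_{index}.pt"


def crop_half(arr: np.ndarray, overlap: int = 5) -> np.ndarray:
    # Assumption: keep the first half of the height axis (second to last)
    # plus `overlap` extra rows; works for (batch, 2, height, width) arrays.
    height = arr.shape[-2]
    return arr[..., : height // 2 + overlap, :]


x = np.zeros((10, 2, 256, 256))
print(output_file(Path("./data"), "train", 0).name)  # samp_train_0.pt
print(crop_half(x).shape)  # (10, 2, 133, 256)
```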

Parameters:
output_path : Path

Directory path where .pt files will be written.

dataset_type : str

Type of dataset being written (e.g., 'train', 'test', 'validation').

amp_phase : bool

If True, the metadata TYPE key will contain 'amp_phase', otherwise 'real_imag'.
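The documentation states only that amp_phase selects the TYPE metadata label. As a hedged illustration of what the two labels mean, the sketch below (hypothetical helper split_channels, not part of pyvisgen) builds the two-channel layout on axis 1 from a complex array in either convention, together with the matching TYPE entry. It assumes the data handed to the writer are already in the representation the label describes.

```python
import numpy as np


def split_channels(z: np.ndarray, amp_phase: bool) -> tuple[np.ndarray, dict[str, str]]:
    # Illustration only: turn a complex (batch, height, width) array into the
    # (batch, 2, height, width) layout and build the matching TYPE entry.
    if amp_phase:
        # Channel 0: amplitude |z|, channel 1: phase angle in radians.
        channels = np.stack([np.abs(z), np.angle(z)], axis=1)
    else:
        # Channel 0: real part, channel 1: imaginary part.
        channels = np.stack([z.real, z.imag], axis=1)
    return channels, {"TYPE": "amp_phase" if amp_phase else "real_imag"}


z = np.full((1, 4, 4), 3 + 4j)  # batch of one 4x4 complex image
channels, meta = split_channels(z, amp_phase=True)
print(channels.shape, meta)  # (1, 2, 4, 4) {'TYPE': 'amp_phase'}
```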

Examples

>>> from pyvisgen.io import PTWriter
>>>
>>> writer = PTWriter(output_path="./data", dataset_type="train", amp_phase=True)
>>> writer.write(x_data, y_data, index=0, bundle_length=len(x_data))

Or as a context manager:

>>> import numpy as np
>>> from pyvisgen.io import PTWriter
>>>
>>> rng = np.random.default_rng()
>>>
>>> with PTWriter(
...     output_path="./data", dataset_type="train", amp_phase=True
... ) as writer:
...     # 5 bundles, each holding 10 samples of shape (2, 256, 256)
...     x_data = rng.uniform(size=(5, 10, 2, 256, 256))
...     y_data = rng.uniform(size=(5, 10, 2, 256, 256))
...
...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
...         writer.write(x, y, index=bundle_id, bundle_length=len(x))

Methods Summary

write(x, y, *, index, bundle_length[, ...])

Write data bundles to individual PyTorch (.pt) files.

Methods Documentation

write(x: ndarray | Tensor, y: ndarray | Tensor, *, index, bundle_length: int, overlap: int = 5, name_x: str = 'X', name_y: str = 'y', **kwargs) -> None

Write data bundles to individual PyTorch (.pt) files.

The input arrays are cropped to half their height (extended by overlap pixels) and validated before being written as sparse tensors to .pt files.
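A minimal sketch of the shape validation implied by the parameter descriptions (4 dimensions, axis 1 of size 2). The helper name validate_pair is hypothetical and the real checks in PTWriter may differ; the conversion to sparse tensors (for example with torch.Tensor.to_sparse) is not shown. Because both NumPy arrays and PyTorch tensors expose .ndim and .shape, the same check applies to either input type.

```python
import numpy as np


def validate_pair(x, y) -> None:
    # Hypothetical check mirroring the documented expectations for x and y:
    # shape (batch, 2, height, width), i.e. 4 dimensions with axis 1 of size 2.
    for name, arr in (("x", x), ("y", y)):
        if arr.ndim != 4:
            raise ValueError(f"{name} must have 4 dimensions, got {arr.ndim}")
        if arr.shape[1] != 2:
            raise ValueError(f"{name} must have size 2 on axis 1, got {arr.shape[1]}")


# A valid pair passes silently.
validate_pair(np.zeros((10, 2, 256, 256)), np.zeros((10, 2, 256, 256)))

try:
    validate_pair(np.zeros((10, 256, 256)), np.zeros((10, 2, 256, 256)))
except ValueError as err:
    print(err)  # x must have 4 dimensions, got 3
```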

Parameters:
x : np.ndarray or torch.Tensor

First array of the FFT pair, with shape (batch, 2, height, width). Expected to have 4 dimensions with axis 1 of size 2.

y : np.ndarray or torch.Tensor

Second array of the FFT pair, with shape (batch, 2, height, width). Expected to have 4 dimensions with axis 1 of size 2.

index : int

Bundle index used in the output filename.

bundle_length : int

Number of samples to write in this bundle.

overlap : int, optional

Number of overlap pixels kept beyond half the image height when extracting half-images. Default: 5.

name_x : str, optional

Key under which the x array is stored in the .pt file. Default: "X".

name_y : str, optional

Key under which the y array is stored in the .pt file. Default: "y".

Examples

>>> import numpy as np
>>> from pyvisgen.io import PTWriter
>>>
>>> rng = np.random.default_rng()
>>>
>>> with PTWriter(
...     output_path="./data", dataset_type="train", amp_phase=True
... ) as writer:
...     x_data = rng.uniform(size=(5, 10, 2, 256, 256))
...     y_data = rng.uniform(size=(5, 10, 2, 256, 256))
...
...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
...         writer.write(x, y, index=bundle_id, bundle_length=len(x))