PTWriter

class pyvisgen.io.datawriters.PTWriter(output_path: Path, dataset_type: str, amp_phase: bool, half_image: bool = True, **kwargs)

Bases: DataWriter

DataWriter class for saving data in PyTorch (.pt) format.

Creates a new .pt file for each sample, named according to the pattern samp_{dataset_type}_{index}.pt. Before writing, the input arrays are cropped to half their height plus overlap extra pixels (set by the overlap argument of write) and validated.
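The naming pattern above can be illustrated with a small path builder. This is a minimal sketch: sample_path is a hypothetical helper, not part of pyvisgen, and how write derives the per-sample index from its bundle index is not documented on this page, so the sketch simply takes the final index as input.

```python
from pathlib import Path


def sample_path(output_path, dataset_type: str, index: int) -> Path:
    """Hypothetical helper: build the file path for one sample.

    Follows the documented pattern samp_{dataset_type}_{index}.pt.
    """
    return Path(output_path) / f"samp_{dataset_type}_{index}.pt"


# File name of the sample with index 3 in a training set written to ./data
print(sample_path("./data", "train", 3).name)  # samp_train_3.pt
```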

Parameters:
output_path : Path

Directory path where .pt files will be written.

dataset_type : str

Type of dataset being written (e.g., ‘train’, ‘test’, ‘validation’).

amp_phase : bool

If True, the metadata TYPE key contains 'amp_phase', otherwise 'real_imag'.
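The mapping from amp_phase to the TYPE metadata value can be sketched as follows. Only the key name TYPE and the two values come from this page; the function metadata_type and the surrounding dict are hypothetical and do not reproduce PTWriter's full metadata layout.

```python
def metadata_type(amp_phase: bool) -> dict:
    """Hypothetical sketch of the TYPE metadata entry described above."""
    # amp_phase=True: data stored as amplitude/phase; otherwise real/imaginary parts
    return {"TYPE": "amp_phase" if amp_phase else "real_imag"}


print(metadata_type(True))   # {'TYPE': 'amp_phase'}
print(metadata_type(False))  # {'TYPE': 'real_imag'}
```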

Examples

>>> writer = PTWriter(output_path="./data", dataset_type="train", amp_phase=True)
>>> writer.write(x_data, y_data, index=0, bundle_length=len(x_data))

Or as a context manager:

>>> import numpy as np
>>> from pyvisgen.io.datawriters import PTWriter
>>> rng = np.random.default_rng()
>>>
>>> with PTWriter(
...     output_path="./data", dataset_type="train", amp_phase=True
... ) as writer:
...     x_data = rng.uniform(size=(5, 10, 2, 256, 256))
...     y_data = rng.uniform(size=(5, 10, 2, 256, 256))
...
...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
...         writer.write(x, y, index=bundle_id, bundle_length=len(x))

Methods Summary

get_half_image(x, y[, overlap])

Crop every image to half its height, plus a small overlap.

test_shapes(array, name)

Validate the shape of input arrays.

write(x, y, *, index, bundle_length[, ...])

Write data bundles to individual PyTorch (.pt) files.

Methods Documentation

get_half_image(x: ndarray, y: ndarray, overlap: int = 5) → tuple[ndarray]

Crop every image to half its height, plus a small overlap.

Parameters:
x : np.ndarray

Simulated data array with shape (B, C, H, W).

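y : np.ndarray

Ground truth array with shape (B, C, H, W).

overlap : int, optional

Number of pixels of overlap kept beyond half the image height. Default: 5.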

Returns:
tuple[np.ndarray, np.ndarray]

Tuple containing the cropped x and y arrays.
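A minimal NumPy sketch of the cropping, assuming the first H // 2 rows plus overlap rows are kept; the actual method may keep a different half or compute the boundary differently. Keeping only half the image is presumably motivated by the conjugate symmetry of the Fourier transform of a real-valued image, which makes the other half redundant. The name half_image_sketch is hypothetical.

```python
import numpy as np


def half_image_sketch(x: np.ndarray, y: np.ndarray, overlap: int = 5):
    """Crop (B, C, H, W) arrays to half their height plus `overlap` rows.

    Assumption: the first H // 2 rows are kept.
    """
    half = x.shape[2] // 2
    return x[:, :, : half + overlap, :], y[:, :, : half + overlap, :]


rng = np.random.default_rng(0)
x = rng.uniform(size=(10, 2, 256, 256))
y = rng.uniform(size=(10, 2, 256, 256))
x_half, y_half = half_image_sketch(x, y)
print(x_half.shape)  # (10, 2, 133, 256): 256 // 2 + 5 rows
```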

test_shapes(array: ndarray, name: str) → None

Validate the shape of input arrays.

Arrays must have the shape (B, C, H, W), where B is the batch size, C the number of channels (2), and H and W the height and width of the images.

Parameters:
array : np.ndarray

Array to validate.

name : str

Name of the array for error reporting.

Raises:
ValueError

If array axis 1 is not size 2.

ValueError

If array does not have exactly 4 dimensions.
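The two documented checks can be sketched as below. check_shape and its error messages are hypothetical; the sketch tests the number of dimensions first, because indexing array.shape[1] on a 1-D array would raise an IndexError instead of the intended ValueError.

```python
import numpy as np


def check_shape(array: np.ndarray, name: str) -> None:
    """Sketch of the documented checks: 4 dimensions, axis 1 of size 2."""
    if array.ndim != 4:
        raise ValueError(
            f"{name} must have 4 dimensions (B, C, H, W), got shape {array.shape}"
        )
    if array.shape[1] != 2:
        raise ValueError(f"Axis 1 of {name} must have size 2, got {array.shape[1]}")


check_shape(np.zeros((4, 2, 64, 64)), "x")  # passes silently

try:
    check_shape(np.zeros((4, 3, 64, 64)), "y")
except ValueError as err:
    print(err)  # Axis 1 of y must have size 2, got 3
```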

write(x: ndarray | Tensor, y: ndarray | Tensor, *, index, bundle_length: int, overlap: int = 5, name_x: str = 'X', name_y: str = 'y', **kwargs) → None

Write data bundles to individual PyTorch (.pt) files.

The input arrays are cropped to half their height plus overlap extra pixels and validated before being written to .pt files as sparse tensors.

Parameters:
x : np.ndarray or torch.Tensor

First array of the FFT pair with shape (batch, 2, height, width), i.e. 4 dimensions with axis 1 of size 2.

y : np.ndarray or torch.Tensor

Second array of the FFT pair with shape (batch, 2, height, width), i.e. 4 dimensions with axis 1 of size 2.

index : int

Bundle index used in the output filename.

bundle_length : int

Number of samples to write in this bundle.

overlap : int, optional

Number of pixels of overlap kept beyond half the image height when cropping (see get_half_image). Default: 5.

name_x : str, optional

Key of the x array in the written .pt file. Default: "X".

name_y : str, optional

Key of the y array in the written .pt file. Default: "y".
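Storing one sample as sparse tensors can be sketched with torch.Tensor.to_sparse and torch.save. This is a sketch under assumptions: save_sparse_pair is hypothetical, and saving a dict keyed by name_x and name_y is an assumed layout, not PTWriter's documented file format. Sparse (COO) storage only saves space when the arrays contain many zeros.

```python
import tempfile
from pathlib import Path

import numpy as np
import torch


def save_sparse_pair(path, x, y, name_x="X", name_y="y"):
    """Hypothetical sketch: store one x/y sample as sparse COO tensors."""
    payload = {
        name_x: torch.as_tensor(x).to_sparse(),  # accepts ndarray or Tensor
        name_y: torch.as_tensor(y).to_sparse(),
    }
    torch.save(payload, path)


rng = np.random.default_rng(0)
x = rng.uniform(size=(2, 133, 256))
x[x < 0.9] = 0.0  # mostly zeros, where sparse storage pays off
y = rng.uniform(size=(2, 133, 256))

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "samp_train_0.pt"
    save_sparse_pair(path, x, y)
    # weights_only=False: the file was written by this script, so it is trusted
    loaded = torch.load(path, weights_only=False)

print(loaded["X"].is_sparse)  # True
print(np.allclose(loaded["X"].to_dense().numpy(), x))  # True
```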

Examples

>>> import numpy as np
>>> from pyvisgen.io.datawriters import PTWriter
>>> rng = np.random.default_rng()
>>>
>>> with PTWriter(
...     output_path="./data", dataset_type="train", amp_phase=True
... ) as writer:
...     x_data = rng.uniform(size=(5, 10, 2, 256, 256))
...     y_data = rng.uniform(size=(5, 10, 2, 256, 256))
...
...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
...         writer.write(x, y, index=bundle_id, bundle_length=len(x))