PTWriter
- class pyvisgen.io.datawriters.PTWriter(output_path: Path, dataset_type: str, amp_phase: bool, **kwargs)
Bases: DataWriter
DataWriter class for saving data in PyTorch (.pt) format.
Creates a new .pt file for each sample, named with the pattern samp_{dataset_type}_{index}.pt. The input arrays are cropped to half their height (with an overlap of overlap pixels) and validated before writing.
- Parameters:
- output_path : Path
Directory path where .pt files will be written.
- dataset_type : str
Type of dataset being written (e.g., ‘train’, ‘test’, ‘validation’).
- amp_phase : bool
If True, the metadata TYPE key will contain ‘amp_phase’, otherwise ‘real_imag’.
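As a rough illustration of the naming scheme and the amp_phase switch described above, the sketch below builds an output path and the corresponding metadata TYPE value. The helper names sample_path and type_tag are hypothetical and not part of pyvisgen; only the filename pattern and the two TYPE strings come from this page.

```python
from pathlib import Path


def sample_path(output_path: Path, dataset_type: str, index: int) -> Path:
    """Hypothetical helper: build the path samp_{dataset_type}_{index}.pt."""
    return Path(output_path) / f"samp_{dataset_type}_{index}.pt"


def type_tag(amp_phase: bool) -> str:
    """Hypothetical helper: the value stored under the metadata TYPE key."""
    return "amp_phase" if amp_phase else "real_imag"


path = sample_path(Path("./data"), "train", 3)
print(path.name)                       # samp_train_3.pt
print(type_tag(True), type_tag(False))  # amp_phase real_imag
```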
Examples
>>> writer = PTWriter(output_path="./data", dataset_type="train", amp_phase=True)
>>> writer.write(x_data, y_data, index=0, bundle_length=len(x_data))
Or as a context manager:
>>> import numpy as np
>>> rng = np.random.default_rng()
>>>
>>> with PTWriter(
...     output_path="./data", dataset_type="train", amp_phase=True
... ) as writer:
...     x_data = rng.uniform(size=(5, 10, 2, 256, 256))
...     y_data = rng.uniform(size=(5, 10, 2, 256, 256))
...
...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
...         writer.write(x, y, index=bundle_id, bundle_length=len(x))
Methods Summary
- get_half_image(x, y[, overlap]): Extract half the height of every image, with a small overlap.
- test_shapes(array, name): Validate the shape of input arrays.
- write(x, y, *, index, bundle_length[, ...]): Write data bundles to individual PyTorch (.pt) files.
Methods Documentation
- get_half_image(x: ndarray, y: ndarray, overlap: int = 5) → tuple[ndarray]
Extract half the height of every image, with a small overlap.
- Parameters:
- x : np.ndarray
Simulated data array with shape (B, C, H, W).
- y : np.ndarray
Ground truth array with shape (B, C, H, W).
- Returns:
- tuple[np.ndarray, np.ndarray]
Tuple containing the cropped x and y arrays.
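A minimal NumPy sketch of the cropping step, assuming the upper half is kept, i.e. rows 0 to H // 2 + overlap along the height axis. Which half pyvisgen actually keeps is not stated on this page, so treat the slice as an assumption.

```python
import numpy as np


def get_half_image(
    x: np.ndarray, y: np.ndarray, overlap: int = 5
) -> tuple[np.ndarray, np.ndarray]:
    """Sketch: keep the first H // 2 + overlap rows of every (B, C, H, W) image.

    Assumption: the upper half (low row indices) is kept; the real
    implementation may choose a different half or offset.
    """
    rows = x.shape[2] // 2 + overlap
    # Slice along the height axis only; batch, channel and width are untouched.
    return x[:, :, :rows, :], y[:, :, :rows, :]


x = np.zeros((4, 2, 256, 256))
y = np.ones((4, 2, 256, 256))
x_half, y_half = get_half_image(x, y)
print(x_half.shape)  # (4, 2, 133, 256)
```

With H = 256 and the default overlap of 5, each cropped image keeps 128 + 5 = 133 rows.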
- test_shapes(array: ndarray, name: str) → None
Validate the shape of input arrays.
Arrays should have the shape (B, C, H, W), where B is the batch size, C the number of channels (2), and H and W the height and width of the images.
- Parameters:
- array : np.ndarray
Array to validate.
- name : str
Name of the array for error reporting.
- Raises:
- ValueError
If array axis 1 is not size 2.
- ValueError
If array does not have exactly 4 dimensions.
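The two checks listed under Raises can be sketched as below. The function is named check_shapes here to mark it as an illustration rather than pyvisgen's own test_shapes; the error messages are hypothetical, and the dimension check runs first so that reading array.shape[1] is always safe.

```python
import numpy as np


def check_shapes(array: np.ndarray, name: str) -> None:
    """Illustrative stand-in for test_shapes: require shape (B, 2, H, W)."""
    # Check the number of dimensions first so that array.shape[1] exists.
    if array.ndim != 4:
        raise ValueError(
            f"Array {name} must have 4 dimensions (B, C, H, W), got {array.ndim}."
        )
    # Axis 1 holds the two channels (amp/phase or real/imag).
    if array.shape[1] != 2:
        raise ValueError(
            f"Axis 1 of array {name} must have size 2, got {array.shape[1]}."
        )


check_shapes(np.zeros((8, 2, 64, 64)), "x")  # passes silently
```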
- write(x: ndarray, y: ndarray, *, index, bundle_length: int, overlap: int = 5, name_x: str = 'X', name_y: str = 'y', **kwargs) → None
Write data bundles to individual PyTorch (.pt) files.
The input arrays are cropped to half their height (with an overlap of overlap pixels) and validated before being written as sparse tensors to .pt files.
- Parameters:
- x : np.ndarray
First array of the FFT pair with shape (batch, 2, height, width). Expected to have 4 dimensions with axis 1 of size 2.
- y : np.ndarray
Second array of the FFT pair with shape (batch, 2, height, width). Expected to have 4 dimensions with axis 1 of size 2.
- index : int
Bundle index used in the output filename.
- bundle_length : int
Number of samples to write in this bundle.
- overlap : int, optional
Number of overlapping pixels used when extracting half-images. Default: 5.
- name_x : str, optional
Key under which the x array is stored in the .pt file. Default: "X".
- name_y : str, optional
Key under which the y array is stored in the .pt file. Default: "y".
Examples
>>> import numpy as np
>>> rng = np.random.default_rng()
>>>
>>> with PTWriter(
...     output_path="./data", dataset_type="train", amp_phase=True
... ) as writer:
...     x_data = rng.uniform(size=(5, 10, 2, 256, 256))
...     y_data = rng.uniform(size=(5, 10, 2, 256, 256))
...
...     for bundle_id, (x, y) in enumerate(zip(x_data, y_data)):
...         writer.write(x, y, index=bundle_id, bundle_length=len(x))
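The sketch below walks through what write does per bundle: validate, crop, then emit one file per sample. It rests on assumptions not confirmed on this page: files are numbered index * bundle_length + i, the upper half of each image is kept, and each file holds a mapping from name_x/name_y to the sample's arrays. The real class writes sparse tensors to disk; that step is left out here (the hypothetical plan_sample_files only returns what would be written) so the sketch runs without PyTorch.

```python
from pathlib import Path

import numpy as np


def plan_sample_files(
    x: np.ndarray,
    y: np.ndarray,
    *,
    index: int,
    bundle_length: int,
    output_path: Path = Path("./data"),
    dataset_type: str = "train",
    overlap: int = 5,
    name_x: str = "X",
    name_y: str = "y",
) -> dict[Path, dict[str, np.ndarray]]:
    """Hypothetical sketch: map each output .pt path to the arrays it would hold.

    Assumptions: one file per sample numbered index * bundle_length + i,
    and the first H // 2 + overlap rows are kept. No file is written.
    """
    for arr in (x, y):
        # Same checks as test_shapes: 4 dimensions, axis 1 of size 2.
        if arr.ndim != 4 or arr.shape[1] != 2:
            raise ValueError(f"expected shape (B, 2, H, W), got {arr.shape}")

    rows = x.shape[2] // 2 + overlap  # cropped height, as in get_half_image
    x_half = x[:, :, :rows, :]
    y_half = y[:, :, :rows, :]

    files: dict[Path, dict[str, np.ndarray]] = {}
    for i in range(bundle_length):
        sample_id = index * bundle_length + i  # assumed global sample number
        path = Path(output_path) / f"samp_{dataset_type}_{sample_id}.pt"
        files[path] = {name_x: x_half[i], name_y: y_half[i]}
    return files


rng = np.random.default_rng(0)
x = rng.uniform(size=(10, 2, 256, 256))
y = rng.uniform(size=(10, 2, 256, 256))
files = plan_sample_files(x, y, index=1, bundle_length=len(x))
```

Under this numbering, bundle 1 of size 10 produces samp_train_10.pt through samp_train_19.pt, so equally sized bundles never overwrite each other's files. That would be one plausible reason for write to take bundle_length, but it is an inference, not documented behavior.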