Source code for pyvisgen.utils.data

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import h5py
import numpy as np
from natsort import natsorted

if TYPE_CHECKING:
    from typing import String, Union

__all__ = ["load_bundles", "open_bundles"]


[docs] def load_bundles( data_path: Union(String, Path), dataset_type: Union(String, None) = None ) -> list: """Loads bundle paths, filters for HDF5 files, and returns them in a naturally ordered list. Parameters ---------- data_path : str or Path Path to the directory containing the HDF5 files. dataset_type : str, optional Type of the dataset to filter, e.g. 'train', 'valid', or 'test'. Returns ------- bundles : list Naturally ordered list containing paths to HDF5 files. """ if isinstance(data_path, str): data_path = Path(data_path) if not dataset_type: filter_fmt = "*.h5" else: filter_fmt = f"*{dataset_type}*.h5" bundles = natsorted(list(data_path.glob(filter_fmt))) return bundles
[docs] def open_bundles(path: str | Path, key: str = "y") -> np.array: """Opens a bundle HDF5 file. Parameters ---------- path : str or Path Path to the bundle HDF5 file. Returns ------- bundle_y : :func:`~numpy.array` :func:`~numpy.array` containing data from the bundle file. """ f = h5py.File(path, "r") bundle_y = np.array(f[key]) return bundle_y