Source code for pyvisgen.gridding.utils

import os
from pathlib import Path

import h5py
import numpy as np
from tqdm.auto import tqdm

from pyvisgen.fits.data import fits_data
from pyvisgen.gridding import grid_data
from pyvisgen.utils.config import read_data_set_conf
from pyvisgen.utils.data import load_bundles, open_bundles
from pyvisgen.utils.logging import setup_logger

os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
LOGGER = setup_logger()


[docs] def create_gridded_data_set(config): conf = read_data_set_conf(config) out_path_fits = Path(conf["out_path_fits"]) out_path = Path(conf["out_path_gridded"]) out_path.mkdir(parents=True, exist_ok=True) sky_dist = load_bundles(conf["in_path"]) fits_files = fits_data(out_path_fits) size = len(fits_files) print(size) # test if conf["num_test_images"] > 0: bundle_test = int(conf["num_test_images"] // conf["bundle_size"]) size -= conf["num_test_images"] for i in tqdm(range(bundle_test)): ( uv_data_test, freq_data_test, gridded_data_test, sky_dist_test, ) = open_data(fits_files, sky_dist, conf, i) truth_fft_test = calc_truth_fft(sky_dist_test) if conf["amp_phase"]: gridded_data_test = convert_amp_phase(gridded_data_test, sky_sim=False) truth_amp_phase_test = convert_amp_phase(truth_fft_test, sky_sim=True) else: gridded_data_test = convert_real_imag(gridded_data_test, sky_sim=False) truth_amp_phase_test = convert_real_imag(truth_fft_test, sky_sim=True) assert gridded_data_test.shape[1] == 2 out = out_path / Path("samp_test" + str(i) + ".h5") save_fft_pair(out, gridded_data_test, truth_amp_phase_test) size_train = int(size // (1 + conf["train_valid_split"])) size_valid = size - size_train print(f"Training size: {size_train}, Validation size: {size_valid}") bundle_train = int(size_train // conf["bundle_size"]) bundle_valid = int(size_valid // conf["bundle_size"]) # train for i in tqdm(range(bundle_train)): i += bundle_test uv_data_train, freq_data_train, gridded_data_train, sky_dist_train = open_data( fits_files, sky_dist, conf, i ) truth_fft_train = calc_truth_fft(sky_dist_train) if conf["amp_phase"]: gridded_data_train = convert_amp_phase(gridded_data_train, sky_sim=False) truth_amp_phase_train = convert_amp_phase(truth_fft_train, sky_sim=True) else: gridded_data_train = convert_real_imag(gridded_data_train, sky_sim=False) truth_amp_phase_train = convert_real_imag(truth_fft_train, sky_sim=True) out = out_path / Path("samp_train" + str(i - bundle_test) + ".h5") save_fft_pair(out, gridded_data_train, truth_amp_phase_train) train_index_last = i # valid for i in tqdm(range(bundle_valid)): i += train_index_last uv_data_valid, freq_data_valid, gridded_data_valid, sky_dist_valid = open_data( fits_files, sky_dist, conf, i ) truth_fft_valid = calc_truth_fft(sky_dist_valid) if conf["amp_phase"]: gridded_data_valid = convert_amp_phase(gridded_data_valid, sky_sim=False) truth_amp_phase_valid = convert_amp_phase(truth_fft_valid, sky_sim=True) else: gridded_data_valid = convert_real_imag(gridded_data_valid, sky_sim=False) truth_amp_phase_valid = convert_real_imag(truth_fft_valid, sky_sim=True) out = out_path / Path("samp_valid" + str(i - train_index_last) + ".h5") save_fft_pair(out, gridded_data_valid, truth_amp_phase_valid)
[docs] def open_data(fits_files, sky_dist, conf, i): sky_sim_bundle_size = len(open_bundles(sky_dist[0])) uv_data = [ fits_files.get_uv_data(n).copy() for n in np.arange( i * sky_sim_bundle_size, (i * sky_sim_bundle_size) + sky_sim_bundle_size ) ] freq_data = np.array( [ fits_files.get_freq_data(n) for n in np.arange( i * sky_sim_bundle_size, (i * sky_sim_bundle_size) + sky_sim_bundle_size ) ], dtype="object", ) gridded_data = np.array( [grid_data(data, freq, conf).copy() for data, freq in zip(uv_data, freq_data)] ) bundle = np.floor_divide(i * sky_sim_bundle_size, sky_sim_bundle_size) gridded_truth = np.array( [ open_bundles(sky_dist[bundle])[n] for n in np.arange( i * sky_sim_bundle_size - bundle * sky_sim_bundle_size, (i * sky_sim_bundle_size) + sky_sim_bundle_size - bundle * sky_sim_bundle_size, ) ] ) return uv_data, freq_data, gridded_data, gridded_truth
[docs] def save_fft_pair(path, x, y, name_x="x", name_y="y"): """ write fft_pairs created in second analysis step to h5 file """ half_image = x.shape[2] // 2 x = x[:, :, : half_image + 1, :] y = y[:, :, : half_image + 1, :] test_shapes(x, "x") test_shapes(y, "y") with h5py.File(path, "w") as hf: hf.create_dataset(name_x, data=x) hf.create_dataset(name_y, data=y) hf.close()
def test_shapes(array, name): if array.shape[1] != 2: raise ValueError( f"Expected array {name} axis 1 to be 2 but got " f"{array.shape} with axis 1: {array.shape[1]}!" ) if len(array.shape) != 4: raise ValueError( f"Expected array {name} shape to be of len 4 but got " f"{array.shape} with len {len(array.shape)}!" )
[docs] def calc_truth_fft(sky_dist): truth_fft = np.fft.fftshift( np.fft.fft2(np.fft.fftshift(sky_dist, axes=(2, 3)), axes=(2, 3)), axes=(2, 3) ) return truth_fft
[docs] def convert_amp_phase(data, sky_sim=False): if sky_sim: amp = np.abs(data) phase = np.angle(data) data = np.concatenate((amp, phase), axis=1) else: test = data[:, 0] + 1j * data[:, 1] amp = np.abs(test) phase = np.angle(test) data = np.stack((amp, phase), axis=1) return data
[docs] def convert_real_imag(data, sky_sim=False): if sky_sim: real = data.real imag = data.imag data = np.concatenate((real, imag), axis=1) else: real = data[:, 0] imag = data[:, 1] data = np.stack((real, imag), axis=1) return data