import inspect
import os
import tomllib
from collections.abc import Callable
from pathlib import Path
from typing import Annotated, Literal
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
from pyvisgen.io import datawriters
from pyvisgen.layouts import get_array_names
# Names exported by ``from <this module> import *``. The remaining section
# models defined here (DataWriterConfig, GriddingConfig, FFTConfig,
# CodeCarbonEmissionTrackerConfig) are deliberately not listed.
__all__ = [
    "Config",
    "NoiseConfig",
    "SamplingConfig",
    "PolarizationConfig",
    "BundleConfig",
]
class SamplingConfig(BaseModel, validate_assignment=True):
    """Sampling config BaseModel.

    Describes the simulated observations: array ``layout``, image size, field
    of view, scan timing and frequency setup. Fields with two-element list
    defaults (``fov_center_ra``, ``fov_center_dec``, ``scan_start``,
    ``scan_duration``, ``num_scans``) presumably hold lower/upper bounds that
    values are drawn from -- TODO confirm against the sampling code.

    Because ``validate_assignment=True``, the validators below also run when a
    field is assigned on an existing instance.
    """

    mode: Literal["full", "grid", "dense"] = "full"
    device: str = "cuda"
    # ``"none"`` and ``False`` are normalized to ``None`` by ``parse_seed``.
    seed: str | bool | int | None = 1337
    layout: str = "vlba"
    img_size: int = Field(default=1024, gt=0)
    fov_center_ra: list[float] = [100, 110]
    fov_center_dec: list[float] = [30, 40]
    fov_size: float = Field(default=0.24, gt=0)
    corr_int_time: float = Field(default=30.0, gt=0)
    scan_start: list[str] = ["01-01-1995 00:00:01", "01-01-2025 23:59:59"]
    scan_duration: list[int] = [20, 600]
    num_scans: list[int] = [6, 10]
    scan_separation: float = Field(default=360, ge=0)
    ref_frequency: float = Field(default=15.17600e9, gt=0)
    frequency_offsets: list[float] = [0e8, 1.28e8, 2.56e8, 3.84e8]
    bandwidths: list[float] = [1.28e8, 1.28e8, 1.28e8, 1.28e8]
    normalize: bool = True
    corrupted: bool = False
    sensitivity_cut: float = Field(default=1e-6, ge=0)

    @field_validator("layout")
    @classmethod
    def validate_layout(cls, layout: str) -> str:
        """Check that ``layout`` is one of the names from ``get_array_names()``.

        Raises:
            ValueError: If the layout name is not available.
        """
        _avail_layouts = get_array_names()
        if layout not in _avail_layouts:
            raise ValueError(
                f"expected 'layout' to be one of {_avail_layouts} but got {layout}"
            )
        return layout

    @field_validator("scan_start")
    @classmethod
    def validate_dates(cls, v: list[str]) -> list[str]:
        """Check that ``scan_start`` has exactly two entries.

        The entries themselves are not parsed here.

        Raises:
            ValueError: If the list does not have length 2.
        """
        if len(v) != 2:
            raise ValueError("expected 'scan_start' to be a list of len 2")
        return v

    @field_validator("seed")
    @classmethod
    def parse_seed(cls, v: str | bool | int | None) -> str | bool | int | None:
        """Normalize the "no seed" spellings ``"none"`` and ``False`` to ``None``.

        Every other value (including the integer ``0``) is returned unchanged.
        """
        # Do not use ``v in {"none", False}``: since ``0 == False`` and
        # ``hash(0) == hash(False)``, that test also matches an explicit
        # ``seed = 0`` and silently discards it.
        if v is False or v == "none":
            return None
        return v
class NoiseConfig(BaseModel, validate_assignment=True):
    """Noise simulation settings (level, mode, telescope and band)."""

    noise_level: Annotated[float, Field(ge=0)] = 0
    noise_mode: Literal["sefd", "tsys"] = "sefd"
    telescope: str = "meerkat"
    # Leaving this as None selects the first band defined in the telescope config.
    band: str | None = None
class PolarizationConfig(BaseModel, validate_assignment=True):
    """Polarization simulation settings.

    Every field accepts the literal string ``"none"``. It is mapped to ``None``
    before type validation, both on construction and on assignment.
    """

    mode: Literal["linear", "circular", "none"] | None = None
    delta: float | None = 45
    amp_ratio: Annotated[float, Field(ge=0.0, le=1.0)] | None = 0.5
    field_order: list[float] | None = [0.01, 0.01]
    field_scale: list[float] | None = [0.0, 1.0]
    field_threshold: float | None = None

    @field_validator(
        "mode",
        "delta",
        "amp_ratio",
        "field_order",
        "field_scale",
        "field_threshold",
        mode="before",
    )
    @classmethod
    def parse_mode_thresh(
        cls, v: str | float | list | None
    ) -> str | float | list | None:
        """Translate the raw string ``"none"`` into ``None``; pass anything else on."""
        return None if v == "none" else v
class BundleConfig(BaseModel, validate_assignment=True):
    """Bundle config BaseModel.

    ``in_path`` and ``out_path`` are expanded (``~``) and resolved to absolute
    ``Path`` objects when validated. ``fits_out_path`` is optional; ``"none"``
    or ``""`` disable it.
    """

    dataset_type: Literal["train", "test", "valid", "none", ""] = "train"
    in_path: str | Path = "./path/to/input/data/"
    out_path: str | Path = "./output/path/"
    fits_out_path: str | Path | None = None
    """Optional secondary UVFITS output directory for the test WSClean pipeline.
    When set, a UVFITS file is written alongside each UVH5 file during
    simulation. Only valid when ``writer = "UVH5Writer"`` in ``[datawriter]``;
    using it together with ``writer = "FITSWriter"`` raises a ``ValueError``.
    """
    overlap: int = 5
    grid_size: int = Field(default=1024, gt=0)
    grid_fov: float = Field(default=0.24, gt=0)
    amp_phase: bool = False

    @field_validator("in_path", "out_path")
    @classmethod
    def expand_path(cls, v: str | Path, info: ValidationInfo) -> Path:
        """Expand and resolve paths.

        Raises:
            ValueError: If the value is ``"none"`` or the empty string.
        """
        if v in {"none", ""}:
            raise ValueError(f"'{info.field_name}' cannot be empty!")
        v = Path(v).expanduser().resolve()
        return v

    @field_validator("fits_out_path", mode="before")
    @classmethod
    def parse_fits_out_path(cls, v: object) -> object:
        """Map ``None``/``"none"``/``""`` to ``None``, resolve other paths.

        This runs before type validation, so ``v`` is the raw input. Inputs
        that are neither strings nor path-like are returned unchanged. That lets
        pydantic's own type check reject them with a ``ValidationError``,
        instead of the ``TypeError`` raised by set membership on unhashable
        values or by ``Path()`` on non-path types.
        """
        if v is None or (isinstance(v, str) and v in {"none", ""}):
            return None
        if isinstance(v, (str, os.PathLike)):
            return Path(v).expanduser().resolve()
        return v
class DataWriterConfig(BaseModel, validate_assignment=True):
    """Data writer config BaseModel.

    ``writer`` accepts either a ``datawriters.DataWriter`` subclass or the name
    of a class found in the ``datawriters`` module. The case-insensitive
    shorthands ``"h5"``/``"hdf5"``, ``"uvh5"``, ``"wds"``/``"webdataset"`` and
    ``"pt"`` are also accepted. After validation ``writer`` holds the resolved
    class.
    """

    writer: str | Callable = datawriters.H5Writer
    overlap: int = Field(default=5, gt=0)
    shard_pattern: str = "%06d.tar"
    compress: bool = False

    @field_validator("writer")
    @classmethod
    def setup_writer(cls, writer) -> Callable:
        """Resolve ``writer`` to a writer class.

        Raises:
            ValueError: If ``writer`` is neither a ``DataWriter`` subclass nor
                the name (or shorthand) of a class in ``datawriters``. Raising
                ``ValueError`` lets pydantic report a ``ValidationError``;
                ``KeyError``/``TypeError`` would escape validation unwrapped.
        """
        # issubclass() only accepts classes; a plain function or callable
        # instance would otherwise raise TypeError here.
        if inspect.isclass(writer) and issubclass(writer, datawriters.DataWriter):
            return writer

        if not isinstance(writer, str):
            raise ValueError(
                "expected 'writer' to be a DataWriter subclass or the name of "
                f"one, but got {writer!r}"
            )

        avail_writers = dict(inspect.getmembers(datawriters, inspect.isclass))

        # Shorthands are matched case-insensitively; full class names exactly.
        shorthands = {
            "h5": "H5Writer",
            "hdf5": "H5Writer",
            "uvh5": "UVH5Writer",
            "wds": "WDSShardWriter",
            "webdataset": "WDSShardWriter",
            "pt": "PTWriter",
        }
        name = shorthands.get(writer.lower(), writer)

        try:
            return avail_writers[name]
        except KeyError:
            raise ValueError(
                f"expected 'writer' to be one of {sorted(avail_writers)} "
                f"or a shorthand in {sorted(shorthands)}, but got {writer!r}"
            ) from None
class GriddingConfig(BaseModel, validate_assignment=True):
    """Gridding config BaseModel; ``gridder`` names the gridding method."""

    gridder: str = Field(default="default")
class FFTConfig(BaseModel, validate_assignment=True):
    """Fourier transform config BaseModel; ``ft`` picks one of the listed variants."""

    ft: Literal["default", "finufft", "reversed"] = Field(default="default")
class CodeCarbonEmissionTrackerConfig(BaseModel, validate_assignment=True):
    """Codecarbon emission tracker configuration"""

    log_level: str | int = "error"
    country_iso_code: str = "DEU"
    # Evaluated per instance rather than once at import time, so the default
    # follows the working directory in effect when the config is created.
    output_dir: str | None = Field(default_factory=os.getcwd)
    # NOTE(review): Config.validate_codecarbon passes ``project_name`` to this
    # model, but no such field is declared here, so under pydantic's default
    # extra handling the value is dropped -- confirm whether a field is wanted.
class Config(BaseModel):
    """Top-level pyvisgen configuration.

    Groups the per-section models. ``from_toml`` passes the top-level keys of
    a TOML file (e.g. ``[sampling]``, ``[bundle]``, ``[datawriter]``) to these
    fields. ``codecarbon`` may be ``False``, ``True`` (default tracker
    settings) or a table of tracker settings.
    """

    sampling: SamplingConfig = Field(default_factory=SamplingConfig)
    noise: NoiseConfig = Field(default_factory=NoiseConfig)
    polarization: PolarizationConfig = Field(default_factory=PolarizationConfig)
    bundle: BundleConfig = Field(default_factory=BundleConfig)
    datawriter: DataWriterConfig = Field(default_factory=DataWriterConfig)
    gridding: GriddingConfig = Field(default_factory=GriddingConfig)
    fft: FFTConfig = Field(default_factory=FFTConfig)
    codecarbon: bool | CodeCarbonEmissionTrackerConfig = False

    @classmethod
    def from_toml(cls, path: str | Path) -> "Config":
        """Load configuration from a TOML file.

        Args:
            path: Path to the TOML file.

        Returns:
            A validated ``Config`` built from the file's top-level tables.
        """
        with open(path, "rb") as f:
            data = tomllib.load(f)
        return cls(**data)

    @model_validator(mode="after")
    def check_fits_out_path_writer(self) -> "Config":
        """Reject ``bundle.fits_out_path`` combined with a FITSWriter writer.

        Raises:
            ValueError: If ``fits_out_path`` is set and the resolved writer is
                ``datawriters.FITSWriter`` or a subclass of it.
        """
        writer = self.datawriter.writer
        # Guard issubclass(): if writer is not a class (e.g. a DataWriterConfig
        # built without validation), issubclass would raise TypeError.
        if (
            self.bundle.fits_out_path is not None
            and inspect.isclass(writer)
            and issubclass(writer, datawriters.FITSWriter)
        ):
            raise ValueError(
                "'fits_out_path' in [bundle] must not be used together with "
                "writer='FITSWriter' in [datawriter] — the FITSWriter already "
                "writes FITS files as its primary output. "
                "Set writer='UVH5Writer' to enable secondary FITS output."
            )
        return self

    @field_validator("codecarbon", mode="before")
    @classmethod
    def validate_codecarbon(
        cls, v: bool | dict | CodeCarbonEmissionTrackerConfig
    ) -> bool | CodeCarbonEmissionTrackerConfig:
        """Turn ``True`` or a settings table into a tracker config.

        ``False`` and ready-made ``CodeCarbonEmissionTrackerConfig`` instances
        are passed through unchanged.
        """
        if isinstance(v, dict):
            # Merge rather than pass project_name as an extra keyword:
            # ``Model(**v, project_name=...)`` raises TypeError when the table
            # already contains ``project_name``. An explicit entry wins.
            return CodeCarbonEmissionTrackerConfig(
                **{"project_name": "pyvisgen", **v}
            )
        if v is True:
            return CodeCarbonEmissionTrackerConfig(project_name="pyvisgen")
        return v