Source code for pyvisgen.io.config

import inspect
import os
import tomllib
from collections.abc import Callable
from pathlib import Path
from typing import Literal

from pydantic import BaseModel, Field, ValidationInfo, field_validator

from pyvisgen.io import datawriters
from pyvisgen.layouts import get_array_names

__all__ = [
    "Config",
    "SamplingConfig",
    "PolarizationConfig",
    "BundleConfig",
]


[docs] class SamplingConfig(BaseModel, validate_assignment=True): """Sampling config BaseModel""" mode: Literal["full", "grid", "dense"] = "full" device: str = "cuda" seed: str | bool | int | None = 1337 layout: str = "vlba" img_size: int = Field(default=1024, gt=0) fov_center_ra: list[int] = [100, 110] fov_center_dec: list[int] = [30, 40] fov_size: float = Field(default=0.24, gt=0) corr_int_time: float = Field(default=30.0, gt=0) scan_start: list[str] = ["01-01-1995 00:00:01", "01-01-2025 23:59:59"] scan_duration: list[int] = [20, 600] num_scans: list[int] = [6, 10] scan_separation: float = Field(default=360, ge=0) ref_frequency: float = Field(default=15.17600e9, gt=0) frequency_offsets: list[float] = [0e8, 1.28e8, 2.56e8, 3.84e8] bandwidths: list[float] = [1.28e8, 1.28e8, 1.28e8, 1.28e8] noisy: int = Field(default=0, ge=0) corrupted: bool = False sensitivity_cut: float = Field(default=1e-6, ge=0)
[docs] @field_validator("layout") @classmethod def validate_layout(cls, layout: str) -> None: _avail_layouts = get_array_names() if layout not in _avail_layouts: raise ValueError( f"expected 'layout' to be one of {_avail_layouts} but got {layout}" ) return layout
[docs] @field_validator("scan_start") @classmethod def validate_dates(cls, v: list[str]) -> None: if len(v) != 2: raise ValueError("expected 'scan_start' to be a list of len 2") return v
[docs] @field_validator("seed") @classmethod def parse_seed(cls, v: str | bool | int | None) -> int | None: if v in {"none", False}: v = None return v
[docs] class PolarizationConfig(BaseModel, validate_assignment=True): """Polarization config BaseModel""" mode: Literal["linear", "circular", "none"] | None = None delta: float = Field(default=45) amp_ratio: float = Field(default=0.5) field_order: list[float] = [0.01, 0.01] field_scale: list[float] = [0, 1] field_threshold: float | Literal["none"] | None = None
[docs] @field_validator("mode", "field_threshold") @classmethod def parse_mode_thresh(cls, v: str | float) -> str | float | None: if v == "none": v = None return v
[docs] class BundleConfig(BaseModel, validate_assignment=True): """Bundle config BaseModel""" dataset_type: Literal["train", "test", "valid", "none", ""] = "train" in_path: str | Path = "./path/to/input/data/" out_path: str | Path = "./output/path/" overlap: int = 5 grid_size: int = Field(default=1024, gt=0) grid_fov: float = Field(default=0.24, gt=0) amp_phase: bool = False
[docs] @field_validator("in_path", "out_path") @classmethod def expand_path(cls, v: Path, info: ValidationInfo) -> Path: """Expand and resolve paths.""" if v in {None, False, "none", ""}: raise ValueError(f"'{info.field_name}' cannot be empty!") else: v = Path(v) v.expanduser().resolve() return v
class DataWriterConfig(BaseModel, validate_assignment=True): writer: str | Callable = datawriters.H5Writer overlap: int = Field(default=5, gt=0) shard_pattern: str = "%06d.tar" compress: bool = False @field_validator("writer") @classmethod def setup_writer(cls, writer) -> Callable: if isinstance(writer, Callable) and issubclass(writer, datawriters.DataWriter): return writer _avail_writers = {} for member in inspect.getmembers(datawriters): if inspect.isclass(member[1]): _avail_writers[member[0]] = member[1] # handle shorthands for full data writer names if writer.lower() in ["h5", "hdf5"]: output_writer = _avail_writers["H5Writer"] elif writer.lower() in ["wds", "webdataset"]: output_writer = _avail_writers["WDSShardWriter"] elif writer.lower() in ["pt", "ptx"]: output_writer = _avail_writers["PTWriter"] else: output_writer = _avail_writers[writer] return output_writer class GriddingConfig(BaseModel, validate_assignment=True): gridder: str = "default" class FFTConfig(BaseModel, validate_assignment=True): ft: Literal["default", "finufft", "reversed"] = "default" class CodeCarbonEmissionTrackerConfig(BaseModel, validate_assignment=True): """Codecarbon emission tracker configuration""" log_level: str | int = "error" country_iso_code: str = "DEU" output_dir: str | None = os.getcwd()
[docs] class Config(BaseModel): """Main training configuration.""" sampling: SamplingConfig = Field(default_factory=SamplingConfig) polarization: PolarizationConfig = Field(default_factory=PolarizationConfig) bundle: BundleConfig = Field(default_factory=BundleConfig) datawriter: DataWriterConfig = Field(default_factory=DataWriterConfig) gridding: GriddingConfig = Field(default_factory=GriddingConfig) fft: FFTConfig = Field(default_factory=FFTConfig) codecarbon: bool | CodeCarbonEmissionTrackerConfig = False
[docs] @classmethod def from_toml(cls, path: str | Path) -> "Config": """Load configuration from a TOML file.""" with open(path, "rb") as f: data = tomllib.load(f) return cls(**data)
[docs] def to_dict(self) -> dict: """Export configuration as a dictionary.""" return self.model_dump()
[docs] @field_validator("codecarbon", mode="after") @classmethod def validate_codecarbon(cls, v: bool | CodeCarbonEmissionTrackerConfig): if isinstance(v, dict): # pragma: no cover return CodeCarbonEmissionTrackerConfig(**v, project_name="pyvisgen") elif v is True: return CodeCarbonEmissionTrackerConfig(project_name="pyvisgen") return v