File size: 5,395 Bytes
b8fae22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Unified experiment configuration.

A single dataclass drives every run. Values can come from (in priority order):
  1. command-line flags (argparse)  2. a YAML file (--config)  3. dataclass defaults.

The same config object is used by train.py / test.py so that a training run and
its evaluation are guaranteed to agree on dataset, model, image size, etc.
"""
from __future__ import annotations

import argparse
import dataclasses
from dataclasses import dataclass, field, asdict
from typing import Optional, List

import yaml


@dataclass
class Config:
    """Single source of truth for one experiment run.

    Every field has a default, so ``Config()`` is always valid. ``from_args``
    layers an optional YAML file and then command-line flags on top of those
    defaults (flags win over YAML, YAML wins over defaults).
    """

    # ---- experiment identity ----
    exp_name: str = "default"                 # results/<exp_name>/<dataset>_<protocol>/<arch>/seed<seed>/
    seed: int = 0

    # ---- data ----
    data_root: str = "dataset/processed_unified"
    dataset: str = "cvc_clinicdb"             # folder name under data_root
    protocol: str = "official"                # e.g. official / fold01 ...
    in_channels: int = 0                      # 0 = auto-detect from metadata/first image
    num_classes: int = 0                      # 0 = auto-detect from metadata/masks (incl. background)
    img_size: int = 256                       # square resize target (Swin/TransUNet need 224)
    # extra synthetic (image,mask) pairs to MERGE into the train split.
    # Points at a dir laid out like a split: <synth_train_dir>/{images,masks}/.
    synth_train_dir: str = ""                 # "" = real data only (no generative augmentation)

    # ---- augmentation (conventional baseline tier) ----
    aug: str = "standard"                     # none | standard | strong  (albumentations online)
    aug_backend: str = "albumentations"       # albumentations | monai
    normalize: str = "auto"                   # auto(imagenet for RGB, 0.5 for gray) | imagenet | none

    # ---- model ----
    arch: str = "unet"                        # see models/registry.py REGISTRY
    encoder: str = "resnet34"                 # SMP encoder name (ignored by non-SMP archs)
    encoder_weights: str = "imagenet"         # imagenet | none
    pretrained_ckpt: str = ""                 # ViT/Swin pretrain for transunet/swinunet (optional)

    # ---- optimization ----
    epochs: int = 100
    batch_size: int = 16                      # per-GPU batch size
    lr: float = 1e-4
    weight_decay: float = 1e-4
    optimizer: str = "adamw"                  # adamw | sgd
    scheduler: str = "poly"                   # poly | cosine | none
    warmup_epochs: int = 0
    loss: str = "ce_dice"                     # ce_dice | ce | dice
    num_workers: int = 8
    grad_clip: float = 0.0                    # 0 = disabled

    # ---- precision / hardware ----
    amp: str = "bf16"                         # bf16(A100+) | fp16(V100) | fp32
    # DDP is driven by torchrun env vars (RANK/WORLD_SIZE/LOCAL_RANK); nothing to set here.

    # ---- evaluation / logging ----
    val_interval: int = 5                     # epochs between validations
    min_epochs: int = 0                       # never early-stop before this many epochs
    patience: int = 0                         # early-stop after this many epochs w/o val improvement (0 = off)
    save_interval: int = 0                    # 0 = only save best + last
    include_background: bool = False          # include class 0 in reported Dice/IoU
    compute_hd95: bool = True
    out_root: str = "results"
    resume: str = ""                          # path to checkpoint to resume from
    visualize: bool = True                    # save overlays at test time
    vis_max: int = 32                         # max number of overlay images to save

    def out_dir(self) -> str:
        """Return the output directory for this run.

        Layout: ``<out_root>/<exp_name>/<dataset>_<protocol>/<arch>/seed<seed>``.
        """
        return f"{self.out_root}/{self.exp_name}/{self.dataset}_{self.protocol}/{self.arch}/seed{self.seed}"

    def to_yaml(self, path: str) -> None:
        """Write every field to ``path`` as YAML, keeping declaration order."""
        # Explicit UTF-8: allow_unicode=True emits raw non-ASCII characters, which
        # the platform default encoding is not guaranteed to be able to encode.
        with open(path, "w", encoding="utf-8") as f:
            yaml.safe_dump(asdict(self), f, sort_keys=False, allow_unicode=True)

    @staticmethod
    def _coerce(name: str, value: object, target: type) -> object:
        """Convert a YAML-loaded ``value`` to the declared type of field ``name``.

        YAML scalars do not always arrive with the field's type: PyYAML reads
        ``1e-4`` (no decimal point) as a string, and ``lr: 1`` as an int.

        Args:
            name: Field name, used only in the error message.
            value: The value as loaded from YAML.
            target: The field's declared type (``bool``, ``int``, ``float`` or ``str``).

        Returns:
            ``value`` converted to ``target``. ``None`` and values for any other
            ``target`` type are returned unchanged.

        Raises:
            ValueError: If ``value`` cannot be interpreted as ``target``.
        """
        if value is None or target not in (bool, int, float, str):
            return value
        try:
            if target is bool:
                if isinstance(value, bool):
                    return value
                token = str(value).strip().lower()
                if token in ("1", "true", "yes", "on"):
                    return True
                if token in ("0", "false", "no", "off"):
                    return False
                raise ValueError("not a boolean")
            if target is str:
                return str(value)
            # bool is a subclass of int; a YAML boolean in a numeric field is a mistake.
            if isinstance(value, bool):
                raise ValueError("boolean given for a numeric field")
            if target is int:
                if isinstance(value, float) and not value.is_integer():
                    raise ValueError("non-integral number")
                return int(value)
            return float(value)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"config field {name!r}: cannot interpret {value!r} as {target.__name__}"
            ) from err

    @classmethod
    def from_args(cls, argv: Optional[List[str]] = None) -> "Config":
        """Build a Config from dataclass defaults, an optional YAML file, and CLI flags.

        Args:
            argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

        Returns:
            A Config where CLI flags override YAML values, which override the
            dataclass defaults. YAML keys that are not Config fields are ignored.

        Raises:
            ValueError: If the YAML document is not a mapping, or one of its
                values cannot be converted to the corresponding field's type.
        """
        # First pass: only grab --config so YAML can set defaults that flags then override.
        pre = argparse.ArgumentParser(add_help=False)
        pre.add_argument("--config", type=str, default="")
        known, _ = pre.parse_known_args(argv)

        specs = dataclasses.fields(cls)
        reference = cls()
        # Field types come from the dataclass defaults, never from YAML-loaded
        # values: a YAML `lr: 1` must not turn --lr into an integer flag.
        # (spec.type is unusable here; under `from __future__ import annotations`
        # it is the string "bool", not the class.)
        kinds = {spec.name: type(getattr(reference, spec.name)) for spec in specs}

        base = reference
        if known.config:
            with open(known.config, encoding="utf-8") as fh:
                ydata = yaml.safe_load(fh) or {}
            if not isinstance(ydata, dict):
                raise ValueError(
                    f"{known.config}: expected a YAML mapping of field names to values, "
                    f"got {type(ydata).__name__}"
                )
            overrides = {key: cls._coerce(key, val, kinds[key])
                         for key, val in ydata.items() if key in kinds}
            base = dataclasses.replace(base, **overrides)

        p = argparse.ArgumentParser(parents=[pre],
                                    description="SegGen unified segmentation framework")
        for spec in specs:
            default = getattr(base, spec.name)
            kind = kinds[spec.name]
            if kind is bool:
                # support --flag / --no-flag
                p.add_argument(f"--{spec.name}", dest=spec.name, action="store_true", default=default)
                p.add_argument(f"--no-{spec.name}", dest=spec.name, action="store_false", default=default)
            else:
                p.add_argument(f"--{spec.name}", type=str if kind is type(None) else kind,
                               default=default)
        ns = p.parse_args(argv)
        return cls(**{spec.name: getattr(ns, spec.name) for spec in specs})