File size: 2,202 Bytes
b8fae22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Training entrypoint.

Single GPU:   python framework/train.py --dataset cvc_clinicdb --arch unet ...
Multi-GPU :   torchrun --nproc_per_node=4 framework/train.py --dataset ... --arch ...
"""
from __future__ import annotations

import os
import sys

# Allow `python framework/train.py`: prepend the repo root (the parent of this
# file's directory) to sys.path so the `framework.*` imports below resolve.
# Must stay above those imports.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import torch
import cv2

# Keep each DataLoader worker single-threaded for OpenCV; parallelism comes from num_workers.
# Without this, cv2 spawns an nproc-sized (~384) thread pool per worker, whose per-op
# dispatch overhead starves the GPU at high resolution (768) -> ~4x slower epochs.
# Done at module import time, i.e. before main() builds any dataset or loader.
cv2.setNumThreads(1)

from framework.config import Config
from framework.engine.distributed import setup_distributed, cleanup_distributed, set_seed, print_main
from framework.models.registry import build_model, required_img_size
from framework.engine.trainer import Trainer


def _global_rank(local_rank: int) -> int:
    """Return this process's global rank, falling back to *local_rank*.

    The value returned by ``setup_distributed()`` is treated here as a
    node-local rank. On a single node local and global rank coincide, but with
    multi-node ``torchrun`` every node repeats local ranks 0..N-1, so seeding by
    local rank alone would hand the same seed to one process per node
    (assuming ``set_seed`` derives per-rank seeds from ``rank`` — TODO confirm).
    When no default process group is initialized, *local_rank* is returned
    unchanged.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return local_rank


def main() -> None:
    """Run one training job end to end.

    Parses ``Config`` from the command line and initialises distributed state.
    It then seeds RNGs and applies any fixed input size the architecture
    requires. Next it probes the training split for channel/class counts and
    builds the model. Finally it hands everything to ``Trainer.fit()``.
    ``cleanup_distributed()`` always runs once distributed setup has
    succeeded, even if a later step raises. The exception still propagates.
    """
    cfg = Config.from_args()

    local_rank = setup_distributed()
    try:
        # Seed by global rank so processes on different nodes do not share a seed.
        set_seed(cfg.seed, rank=_global_rank(local_rank))

        # Some backbones require a fixed input size. The override is applied after
        # setup_distributed() so print_main runs with distributed state in place.
        # That matters if print_main keys off it (TODO confirm).
        req = required_img_size(cfg.arch)
        if req and cfg.img_size != req:
            print_main(f"[info] arch '{cfg.arch}' requires img_size={req}; overriding {cfg.img_size}.")
            cfg.img_size = req

        # Peek at the training split to get in/out channels before building the model.
        from framework.data.loaders import build_dataset
        probe = build_dataset(cfg, "train")
        in_ch, n_cls = probe.in_channels, probe.num_classes
        print_main(f"[data] {cfg.dataset}/{cfg.protocol}: in_channels={in_ch} num_classes={n_cls} "
                   f"train={len(probe)}")
        # The probe is only needed for metadata; drop our reference before training.
        del probe

        model = build_model(cfg.arch, in_channels=in_ch, num_classes=n_cls,
                            img_size=cfg.img_size, encoder=cfg.encoder,
                            encoder_weights=cfg.encoder_weights,
                            pretrained_ckpt=cfg.pretrained_ckpt)
        n_params = sum(p.numel() for p in model.parameters())
        print_main(f"[model] {cfg.arch} params={n_params/1e6:.1f}M "
                   f"amp={cfg.amp}")

        trainer = Trainer(cfg, model, local_rank)
        trainer.fit()
    finally:
        # Tear down distributed state on both success and failure.
        cleanup_distributed()


if __name__ == "__main__":
    main()