File size: 4,196 Bytes
b8fae22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""Re-score an existing framework checkpoint at a COMMON evaluation resolution R.

Needed for resolution-fair comparison: a model that takes a fixed input size
(SwinUNet=224, TransUNet=256) is run at its native input size, but its prediction
and the GROUND TRUTH (read from disk at native resolution, not the 256-degraded dataloader copy) are
both resized to R, and metrics are computed at R — matching how the conv methods
(trained at R) and nnU-Net/U-Mamba (re-scored with --eval_size R) are evaluated.

  python framework/eval_at_res.py --data_root <root> --dataset fives --protocol official \
     --arch swinunet --seed 0 --eval_size 768 --exp_name baselines
"""
from __future__ import annotations

import os
import sys
import json
import argparse

import numpy as np
import cv2
import torch

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from framework.models.registry import build_model, required_img_size
from framework.metrics.metrics import per_image_metrics, aggregate
from framework.data.unified_dataset import UnifiedSegDataset
from framework.data.transforms import build_transform


def _to_R(arr, R):
    """Nearest-neighbour resize a 2-D integer label map to ``R x R``; return int64.

    OpenCV's ``resize`` does not accept int64 input, so the labels are cast to
    an unsigned type first. uint8 is used when every label fits (the common
    case); uint16 otherwise, so labels >= 256 are not silently wrapped.
    Nearest-neighbour interpolation keeps the output values inside the input's
    label set (no blending between classes).

    Args:
        arr: 2-D array of non-negative integer labels.
        R: target side length in pixels.

    Raises:
        ValueError: if a label is negative or exceeds the uint16 range, since
            the cast would otherwise corrupt it.
    """
    arr = np.asarray(arr)
    lo, hi = int(arr.min()), int(arr.max())
    if lo < 0 or hi > np.iinfo(np.uint16).max:
        raise ValueError(f"label values must lie in [0, {np.iinfo(np.uint16).max}], got [{lo}, {hi}]")
    dtype = np.uint8 if hi <= np.iinfo(np.uint8).max else np.uint16
    # dsize is (width, height); both are R so the square target is unambiguous.
    resized = cv2.resize(arr.astype(dtype), (R, R), interpolation=cv2.INTER_NEAREST)
    return resized.astype(np.int64)


def _parse_args(argv=None):
    """Parse the command line (``argv=None`` reads ``sys.argv``); reject a non-positive R."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_root", required=True)
    ap.add_argument("--dataset", required=True)
    ap.add_argument("--protocol", required=True)
    ap.add_argument("--arch", required=True)
    ap.add_argument("--encoder", default="resnet50")
    ap.add_argument("--seed", type=int, required=True)
    ap.add_argument("--eval_size", type=int, required=True, help="common resolution R")
    ap.add_argument("--exp_name", default="baselines")
    ap.add_argument("--out_root", default="results")
    ap.add_argument("--normalize", default="auto")
    ap.add_argument("--out_name", default="metrics.json",
                    help="output file name inside the run directory "
                         "(the default overwrites any existing metrics.json there)")
    args = ap.parse_args(argv)
    if args.eval_size <= 0:
        ap.error(f"--eval_size must be a positive integer, got {args.eval_size}")
    return args


def _read_gt(path):
    """Read a ground-truth mask from disk at native resolution as a 2-D uint8 array.

    ``cv2.imread`` returns ``None`` (rather than raising) for missing or
    unreadable files, so that case is turned into an explicit error here.
    """
    gt = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if gt is None:
        raise SystemExit(f"could not read ground-truth mask: {path}")
    return gt


def main():
    """Re-score one run's ``best.pth`` at the common resolution R and write a metrics JSON.

    The run directory is ``<out_root>/<exp_name>/<dataset>_<protocol>/<arch>/seed<seed>``.
    The model is fed test images at its required input size
    (``required_img_size(arch)``, falling back to R). Its argmax prediction and
    the ground truth read from disk at native resolution are both
    nearest-resized to R x R. ``per_image_metrics`` then scores them with the
    background class excluded and HD95 enabled.

    Raises:
        SystemExit: if the test split is empty, the checkpoint is missing, or a
            ground-truth mask cannot be read.
    """
    args = _parse_args()

    R = args.eval_size
    model_res = required_img_size(args.arch) or R          # SwinUNet 224 / TransUNet 256 / conv -> R

    ds = UnifiedSegDataset(args.data_root, args.dataset, args.protocol, "test", transform=None)
    if len(ds) == 0:
        raise SystemExit(f"no test samples for {args.dataset}/{args.protocol} under {args.data_root}")
    # The eval transform is built after construction because it needs ds.in_channels.
    ds.transform = build_transform(model_res, ds.in_channels, train=False,
                                   aug="none", normalize=args.normalize)
    num_classes = ds.num_classes

    out_dir = os.path.join(args.out_root, args.exp_name,
                           f"{args.dataset}_{args.protocol}", args.arch, f"seed{args.seed}")
    ckpt_path = os.path.join(out_dir, "best.pth")
    if not os.path.isfile(ckpt_path):
        raise SystemExit(f"checkpoint not found: {ckpt_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(args.arch, in_channels=ds.in_channels, num_classes=num_classes,
                        img_size=model_res, encoder=args.encoder,
                        encoder_weights="none", pretrained_ckpt="")
    # weights_only=False fully unpickles the file: only point this at trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # Accept both {"model": state_dict, ...} bundles and bare state dicts.
    model.load_state_dict(ckpt.get("model", ckpt))
    model = model.to(device).eval()

    records = []
    out_of_range = []  # GT files containing labels outside 0..num_classes-1
    with torch.no_grad():
        for idx in range(len(ds)):
            item = ds[idx]
            img = item["image"].unsqueeze(0).to(device)         # presumably 1,C,model_res,model_res
            pred = model(img).argmax(1)[0].cpu().numpy()        # per-pixel class index
            gt_path = ds.pairs[idx][1]
            gt = _read_gt(gt_path)                              # native H x W, expected 0..C-1
            if int(gt.max()) >= num_classes:
                out_of_range.append(gt_path)
            records.append(per_image_metrics(_to_R(pred, R), _to_R(gt, R), num_classes,
                                             include_background=False, compute_hd95=True))

    if out_of_range:
        # Not fatal (the metric code may treat such labels specially), but
        # e.g. 0/255 binary masks would make per-class scores meaningless.
        print(f"[eval_at_res] WARNING: {len(out_of_range)} GT mask(s) contain labels >= "
              f"num_classes={num_classes} (first: {out_of_range[0]}); metrics may be invalid",
              file=sys.stderr)

    agg = aggregate(records)
    out = {"dataset": args.dataset, "protocol": args.protocol, "arch": args.arch,
           "seed": args.seed, "num_classes": num_classes,
           "eval_size": R, "metrics": agg, "per_image": records}
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, args.out_name), "w") as f:
        json.dump(out, f, indent=2)
    print(f"[eval_at_res] {args.dataset}/{args.protocol} {args.arch} seed{args.seed} @R={R}: "
          f"n={len(records)} dice={agg['dice_mean']:.4f} hd95={agg['hd95_mean']:.2f} -> {out_dir}")


if __name__ == "__main__":
    main()