File size: 4,324 Bytes
b8fae22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Compute framework-format metrics.json from nnU-Net / U-Mamba test predictions.

nnU-Net only reports validation Dice during training. To compare on the SAME
held-out test set with the SAME 7 metrics as the framework, we first predict on
imagesTs (done separately via nnUNetv2_predict) and then run THIS script, which
scores the predicted masks against labelsTs using framework/metrics/metrics.py and
writes a metrics.json in the exact framework format so that report/aggregate.py
includes the nnU-Net/U-Mamba rows.

  python framework/nnunet_eval.py --data_root <processed_unified> --dataset <ds> \
     --protocol <proto> --raw <nnUNet_raw> --dataset_id <ID> --fold <f> \
     --pred_dir <predictions> --arch nnunet --exp_name baselines
"""
from __future__ import annotations

import os
import sys
import json
import glob
import argparse

import numpy as np
import cv2

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from framework.metrics.metrics import per_image_metrics, aggregate
from framework.data.unified_dataset import _read_metadata, detect_num_classes


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script (flags and defaults are the CLI contract)."""
    ap = argparse.ArgumentParser(
        description="Score nnU-Net / U-Mamba test predictions into a framework-format metrics.json.")
    ap.add_argument("--data_root", required=True)
    ap.add_argument("--dataset", required=True)
    ap.add_argument("--protocol", required=True)
    ap.add_argument("--raw", required=True, help="nnUNet_raw root")
    ap.add_argument("--dataset_id", type=int, required=True)
    ap.add_argument("--name", default="", help="Dataset<DDD>_<name> suffix; default <dataset>_<protocol>")
    ap.add_argument("--fold", type=int, required=True)
    ap.add_argument("--pred_dir", required=True)
    ap.add_argument("--arch", default="nnunet")
    ap.add_argument("--exp_name", default="baselines")
    ap.add_argument("--out_root", default="results")
    ap.add_argument("--include_background", action="store_true")
    ap.add_argument("--eval_size", type=int, default=0,
                    help="resize pred+gt to R×R (nearest) before scoring; 0 = native GT resolution")
    return ap


def _read_label(path: str) -> np.ndarray | None:
    """Read a single-channel label PNG; return None if OpenCV cannot decode it.

    IMREAD_ANYDEPTH (without IMREAD_COLOR) loads a single-channel image and keeps
    16-bit data as 16-bit. Plain IMREAD_GRAYSCALE would convert a 16-bit label map
    to 8-bit, which can destroy class ids. 8-bit PNGs load exactly as before.
    """
    return cv2.imread(path, cv2.IMREAD_GRAYSCALE | cv2.IMREAD_ANYDEPTH)


def _load_pair(pred_path: str, gt_path: str, eval_size: int) -> tuple[np.ndarray, np.ndarray] | None:
    """Load a (pred, gt) label pair on a common grid as int64 arrays, or None if either is unreadable.

    The prediction is resized onto the GT grid when the shapes differ. When
    ``eval_size > 0``, both maps are then resized to ``eval_size x eval_size``.
    All resizing is nearest-neighbour, so label values stay discrete.
    """
    pred = _read_label(pred_path)
    gt = _read_label(gt_path)
    if pred is None or gt is None:
        return None
    if pred.shape != gt.shape:
        # cv2.resize takes dsize as (width, height).
        pred = cv2.resize(pred, (gt.shape[1], gt.shape[0]), interpolation=cv2.INTER_NEAREST)
    if eval_size > 0:  # resolution-fair common-R scoring
        pred = cv2.resize(pred, (eval_size, eval_size), interpolation=cv2.INTER_NEAREST)
        gt = cv2.resize(gt, (eval_size, eval_size), interpolation=cv2.INTER_NEAREST)
    return pred.astype(np.int64), gt.astype(np.int64)


def main():
    """CLI entry point: score the PNGs in --pred_dir against labelsTs and write metrics.json.

    Predictions and ground truth are matched by file basename. The output is
    written to <out_root>/<exp_name>/<dataset>_<protocol>/<arch>/seed<fold>/metrics.json.
    The nnU-Net fold number is stored under the "seed" key so the file matches
    the framework's per-seed layout.

    Raises:
        SystemExit: if the labelsTs directory does not exist, or if no prediction
            has a readable matching ground-truth mask.
    """
    args = _build_arg_parser().parse_args()

    name = args.name or f"{args.dataset}_{args.protocol}"
    dsname = f"Dataset{args.dataset_id:03d}_{name}"
    lab_dir = os.path.join(args.raw, dsname, "labelsTs")
    if not os.path.isdir(lab_dir):
        # Fail clearly here rather than later with an empty GT list.
        raise SystemExit(f"labelsTs directory not found: {lab_dir}")

    meta = _read_metadata(args.data_root, args.dataset)
    gt_masks = sorted(glob.glob(os.path.join(lab_dir, "*.png")))
    num_classes = detect_num_classes(meta, gt_masks, args.dataset)

    pred_paths = sorted(glob.glob(os.path.join(args.pred_dir, "*.png")))
    pred_names = {os.path.basename(p) for p in pred_paths}
    # Test cases that have a GT mask but no prediction are not scored below.
    # Count them so that a partial prediction run does not go unnoticed.
    n_no_pred = sum(os.path.basename(g) not in pred_names for g in gt_masks)

    records = []
    n_missing = 0  # predictions with no matching GT file, or with an unreadable pred/GT image
    for pp in pred_paths:
        gp = os.path.join(lab_dir, os.path.basename(pp))
        pair = _load_pair(pp, gp, args.eval_size) if os.path.isfile(gp) else None
        if pair is None:
            n_missing += 1
            continue
        pred, gt = pair
        records.append(per_image_metrics(pred, gt, num_classes,
                                         include_background=args.include_background,
                                         compute_hd95=True))
    if not records:
        raise SystemExit(f"no matched (pred,gt) pairs in {args.pred_dir} vs {lab_dir}")
    if n_no_pred:
        print(f"[nnunet_eval] WARNING: {n_no_pred}/{len(gt_masks)} labelsTs cases have no "
              f"prediction in {args.pred_dir}; metrics cover only the predicted cases",
              file=sys.stderr)

    agg = aggregate(records)
    out = {"dataset": args.dataset, "protocol": args.protocol, "arch": args.arch,
           "seed": args.fold, "num_classes": num_classes, "metrics": agg, "per_image": records}
    out_dir = os.path.join(args.out_root, args.exp_name, f"{args.dataset}_{args.protocol}",
                           args.arch, f"seed{args.fold}")
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"[nnunet_eval] {dsname} fold{args.fold}: n={len(records)} (missing {n_missing}) "
          f"dice={agg['dice_mean']:.4f} iou={agg['iou_mean']:.4f} hd95={agg['hd95_mean']:.2f} "
          f"-> {out_dir}/metrics.json")


if __name__ == "__main__":
    # Run the CLI when executed as a script. main() returns None, so the
    # interpreter exits with status 0 unless main() itself raises SystemExit.
    raise SystemExit(main())