File size: 1,840 Bytes
b8fae22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Save a side-by-side overlay: input | ground-truth | prediction.

Used at test time to qualitatively inspect each method's output. Denormalizes the
input tensor back to a viewable image and color-codes class masks.
"""
from __future__ import annotations

import os
import warnings

import cv2
import numpy as np
import torch

# Per-channel ImageNet mean / std, used to undo the normalization of
# three-channel inputs before display.
_IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406))
_IMAGENET_STD = np.asarray((0.229, 0.224, 0.225))

# Overlay colors indexed by class id: row 0 is the background (black),
# rows 1-6 are six distinct foreground colors.
_PALETTE = np.asarray(
    (
        (0, 0, 0),
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
        (255, 0, 255),
        (0, 255, 255),
    ),
    dtype=np.uint8,
)


def _denorm(img: torch.Tensor) -> np.ndarray:
    """Convert a normalized C,H,W image tensor into an H,W,C uint8 array.

    Three-channel inputs are un-normalized with ``_IMAGENET_STD`` and
    ``_IMAGENET_MEAN``. Any other channel count is mapped with
    ``x * 0.5 + 0.5`` (presumably undoing a mean=0.5 / std=0.5
    normalization), and a single-channel result is repeated to three
    channels so it can be blended with color masks.

    Args:
        img: Image tensor laid out as (C, H, W). It may live on any device
            and may be attached to an autograd graph.

    Returns:
        uint8 array laid out as (H, W, C'), scaled by 255 and clipped to
        [0, 255]. C' is 3 for one- and three-channel inputs; other channel
        counts are returned with their channel count unchanged.
    """
    # Tensor.numpy() raises for tensors that require grad or that are not on
    # the CPU, so detach and move to host memory first.
    x = img.detach().cpu().float().numpy()        # C,H,W
    c = x.shape[0]
    x = np.transpose(x, (1, 2, 0))                # H,W,C
    if c == 3:
        x = x * _IMAGENET_STD + _IMAGENET_MEAN
    else:
        x = x * 0.5 + 0.5
        x = np.repeat(x, 3, axis=2) if x.shape[2] == 1 else x
    # Values are expected in [0, 1] here; clip guards against overshoot
    # before the (truncating) cast to uint8.
    x = np.clip(x * 255.0, 0, 255).astype(np.uint8)
    return x


def _colorize(mask: np.ndarray, num_classes: int) -> np.ndarray:
    """Turn a 2-D class-index mask into an H,W,3 uint8 color image.

    Pixels equal to 0, and pixels whose value is not in
    ``1 .. num_classes - 1``, stay black. Every other class is painted with
    a color from ``_PALETTE``.

    Args:
        mask: 2-D array of class indices.
        num_classes: Number of classes including background (class 0).

    Returns:
        Array of shape (H, W, 3) and dtype uint8.
    """
    h, w = mask.shape
    out = np.zeros((h, w, 3), dtype=np.uint8)
    # Row 0 of the palette is the background color. When there are more
    # classes than palette rows, wrap around the foreground rows only;
    # a plain ``c % len(_PALETTE)`` would paint class 7, 14, ... black and
    # make them indistinguishable from background.
    n_foreground = len(_PALETTE) - 1
    for c in range(1, num_classes):
        out[mask == c] = _PALETTE[1 + (c - 1) % n_foreground]
    return out


def save_overlay(img: torch.Tensor, gt: np.ndarray, pred: np.ndarray,
                 num_classes: int, path: str, alpha: float = 0.5) -> None:
    """Write a three-panel image: input | ground-truth overlay | prediction overlay.

    The input tensor is denormalized, resized to the ground-truth resolution
    and blended with the color-coded masks; the three panels are concatenated
    horizontally and written with OpenCV.

    Args:
        img: Normalized image tensor laid out as (C, H, W).
        gt: 2-D ground-truth class-index mask; its shape defines the panel size.
        pred: 2-D predicted class-index mask. If its resolution differs from
            ``gt`` the colorized prediction is resized with nearest-neighbor
            interpolation.
        num_classes: Number of classes including background (class 0).
        path: Output file path (``str`` or ``os.PathLike``). Missing parent
            directories are created.
        alpha: Blend weight of the color mask; the image gets ``1 - alpha``.

    Warns:
        RuntimeWarning: if ``cv2.imwrite`` reports that the file was not written.
    """
    # cv2.imwrite expects a plain string path.
    path = os.fspath(path)
    gt = np.asarray(gt)
    pred = np.asarray(pred)

    base = _denorm(img)
    h, w = gt.shape
    base = cv2.resize(base, (w, h), interpolation=cv2.INTER_LINEAR)
    gt_c = _colorize(gt, num_classes)
    pr_c = _colorize(pred, num_classes)
    if pr_c.shape[:2] != (h, w):
        # Nearest-neighbor keeps the class colors crisp (no blended hues).
        pr_c = cv2.resize(pr_c, (w, h), interpolation=cv2.INTER_NEAREST)
    gt_o = cv2.addWeighted(base, 1 - alpha, gt_c, alpha, 0)
    pr_o = cv2.addWeighted(base, 1 - alpha, pr_c, alpha, 0)
    panel = np.concatenate([base, gt_o, pr_o], axis=1)

    # cv2.imwrite does not create directories; it just returns False.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # The panel is RGB, OpenCV writes BGR.
    written = cv2.imwrite(path, cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))
    if not written:
        warnings.warn(f"cv2.imwrite failed to write overlay to {path!r}",
                      RuntimeWarning, stacklevel=2)