File size: 2,590 Bytes
1a18f22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Generation-side Precision/Recall (Kynkaanniemi) + Density/Coverage (Naeem) on
InceptionV3 (pytorch-fid) features. Precision=fidelity (fake in real manifold),
Recall=diversity/coverage (real covered by fake). No sklearn: kNN via torch.cdist."""
import os, sys, json, random
import numpy as np, torch
from PIL import Image
sys.path.insert(0, "/home/wzhang/LSC/Code/NPJ")
from framework.synth.pixdiff.fd_loss import InceptionFeatures

# Root of the processed datasets. Each DSETS entry maps a short dataset key to
# (dataset directory, protocol/split subdirectory) under DR.
DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified"
DSETS = {"isic": ("medsegdb_isic2018", "holdout"), "kvasir": ("kvasir_seg", "official"), "busi": ("busi", "fold01")}
# Synthetic-set tags; each selects the directory synth_fid_{bk}_{dk} per dataset.
BKS = ["jit", "pixelgen", "deco", "pixeldit"]
# Fall back to CPU so the script still runs (slowly) on machines without CUDA.
dev = "cuda" if torch.cuda.is_available() else "cpu"
CAP = 2000  # max images read per directory; larger dirs are randomly subsampled
K = 5  # neighbour rank for the k-NN ball radii used by PRDC
inc = InceptionFeatures().to(dev).eval()

def feats(d, cap=CAP, bs=64, size=256):
    """Return Inception features for the images in directory ``d``.

    File selection:
      * Only files ending in .png/.jpg/.jpeg (case-insensitive) are used.
      * They are taken in sorted filename order.
      * If there are more than ``cap`` such files, ``cap`` of them are sampled
        with a fixed seed (0), so repeated runs pick the same subset.

    Preprocessing: each image is converted to RGB, resized to
    ``size`` x ``size``, scaled to [0, 1] as a (3, H, W) float tensor, and
    passed through the module-level ``inc`` model in batches of ``bs``.

    Returns:
        The per-batch outputs of ``inc``, concatenated along dim 0, on the CPU.
        The shape is presumably (n_images, feat_dim); this depends on
        InceptionFeatures, so confirm it there.

    Raises:
        ValueError: if ``d`` contains no matching image files.
    """
    fs = sorted(f for f in os.listdir(d) if f.lower().endswith((".png", ".jpg", ".jpeg")))
    if not fs:
        # torch.cat([]) below would otherwise fail with an unhelpful message.
        raise ValueError(f"no .png/.jpg/.jpeg images found in {d}")
    if len(fs) > cap:
        # A private RNG seeded with 0 draws the same subset as
        # random.seed(0); random.sample(...) without resetting global state.
        fs = random.Random(0).sample(fs, cap)
    out = []
    for i in range(0, len(fs), bs):
        b = []
        for f in fs[i:i + bs]:
            # PIL opens files lazily and keeps the handle; close it promptly.
            with Image.open(os.path.join(d, f)) as src:
                rgb = src.convert("RGB").resize((size, size))
            # np.array copies into a writable buffer (np.asarray of a PIL image
            # is read-only, which makes torch.from_numpy warn).
            b.append(torch.from_numpy(np.array(rgb)).permute(2, 0, 1).float() / 255.)
        with torch.no_grad():
            out.append(inc(torch.stack(b).to(dev)).cpu())
    return torch.cat(out)

def knn_radius(X, k):
    """Return each row's distance to its k-th nearest *other* row of ``X``.

    Args:
        X: 2-D tensor of shape (n, dim); each row is one feature vector.
        k: neighbour rank, must satisfy 1 <= k <= n - 1.

    Returns:
        1-D tensor of length n. Entry i is the Euclidean distance from X[i]
        to its k-th nearest neighbour, with X[i] itself excluded.

    Raises:
        ValueError: if ``k`` is outside [1, n - 1]. Without this check,
            k == n would silently return +inf radii, because the masked
            diagonal is selected.
    """
    n = X.shape[0]
    if not 1 <= k < n:
        raise ValueError(f"k must be in [1, {n - 1}] for {n} samples, got {k}")
    d = torch.cdist(X, X)
    # Mask the zero self-distance so a point is never its own neighbour.
    d.fill_diagonal_(float("inf"))
    return d.kthvalue(k, dim=1).values

def prdc(R, F, k=K):
    """Compute precision, recall, density and coverage of fake vs. real features.

    Args:
        R: real feature matrix, shape (N, dim).
        F: fake (generated) feature matrix, shape (M, dim).
        k: neighbour rank for the k-NN ball radii; ``knn_radius`` requires
            k < N and k < M.

    Returns:
        Tuple ``(precision, recall, density, coverage)`` of Python floats:
          * precision: fraction of fake samples inside at least one real
            k-NN ball.
          * recall: fraction of real samples inside at least one fake
            k-NN ball.
          * density: mean number of real k-NN balls containing each fake
            sample, divided by k.
          * coverage: fraction of real k-NN balls that contain at least one
            fake sample.
    """
    R, F = R.to(dev), F.to(dev)
    real_r = knn_radius(R, k)
    fake_r = knn_radius(F, k)
    drf = torch.cdist(R, F)  # (N, M): distance from real i to fake j
    # in_real[i, j]: fake j lies in real i's k-NN ball.
    # Computed once and shared by precision, density and coverage.
    in_real = drf <= real_r[:, None]
    # in_fake[i, j]: real i lies in fake j's k-NN ball.
    in_fake = drf <= fake_r[None, :]
    prec = in_real.any(0).float().mean().item()
    rec = in_fake.any(1).float().mean().item()
    dens = (in_real.sum(0).float().mean() / k).item()
    cov = in_real.any(1).float().mean().item()
    return prec, rec, dens, cov

# Reference features from each dataset's real training images, keyed like DSETS.
realf = {}
for key in DSETS:
    name, split = DSETS[key]
    real_dir = f"{DR}/{name}/{split}/train/images"
    realf[key] = feats(real_dir)
    print(f"[real] {key}: {realf[key].shape}", flush=True)

# Score every synthetic set (one per tag in BKS and dataset in DSETS) against
# that dataset's real features. The results JSON is rewritten after each entry
# so partial results survive an interrupted run.
OUT_JSON = "/home/wzhang/LSC/Code/NPJ/logs/fidviz/gen_prdc.json"
# Create the log dir up front; a missing dir would otherwise crash only after
# the first (expensive) feature extraction.
os.makedirs(os.path.dirname(OUT_JSON), exist_ok=True)
res = {}
for bk in BKS:
    for dk, (ds, proto) in DSETS.items():
        sd = f"{DR}/{ds}/{proto}/synth_fid_{bk}_{dk}/images"
        if not os.path.isdir(sd):
            print(f"[skip] {dk} {bk}", flush=True)
            continue
        F = feats(sd)
        p, r, de, c = prdc(realf[dk], F)
        key = f"{dk}_{bk}"
        res[key] = {"precision": round(p, 3), "recall": round(r, 3), "density": round(de, 3), "coverage": round(c, 3)}
        print(f"[PRDC] {dk} {bk}: {res[key]}", flush=True)
        # Close the handle deterministically instead of relying on GC.
        with open(OUT_JSON, "w") as fh:
            json.dump(res, fh, indent=2)
print("PRDC_DONE", flush=True)