File size: 3,855 Bytes
1a18f22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Same-mask JiT(native) vs JiT-FD comparison: does FD-perceptual sharpen the synth?
Samples JiT-FD on the SAME no-aug f50 masks the native JiT align-set used, builds
[mask | real | JiT native | JiT-FD] grids for ISIC + Kvasir."""
import os, subprocess
import numpy as np
import matplotlib; matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image

# Project checkout; the sampler subprocess is launched with this as its cwd.
ROOT = "/home/wzhang/LSC/Code/NPJ"
# Root of the unified segmentation datasets (passed to the sampler as --data_root).
DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified"
# Interpreter used to run the sampler module.
PY = "/opt/anaconda3/envs/seggen/bin/python"
# short key -> (dataset directory under DR, split protocol, dataset size used to
# turn "50 masks" into a --train_fraction)
DSETS = {
    "isic": ("medsegdb_isic2018", "holdout", 2582),
    "kvasir": ("kvasir_seg", "official", 800),
}

def sample(ckpt, ds, proto, frac, out, *, min_images=40, gpu="0", num_steps=50):
    """Run the pixdiff sampler for one checkpoint/dataset unless ``out`` already has samples.

    Args:
        ckpt: checkpoint path passed as ``--ckpt``. The sampler runs with
            ``cwd=ROOT``, so a relative path resolves against ROOT.
        ds: dataset name passed as ``--dataset``.
        proto: split protocol passed as ``--protocol``.
        frac: value for ``--train_fraction`` (always with ``--fraction_seed 0``).
        out: sampler ``--out_dir``. Samples are presumably written as PNGs into
            ``out/images``; that is the directory checked here and read by ``fmap``.
        min_images: skip sampling when ``out/images`` already holds at least this
            many ``.png`` files.
        gpu: value for ``CUDA_VISIBLE_DEVICES``.
        num_steps: value for ``--num_steps``.

    Raises:
        subprocess.CalledProcessError: if the sampler exits with a non-zero status.
    """
    img_dir = os.path.join(out, "images")
    if os.path.isdir(img_dir):
        # Count only PNGs: those are the only files the grid builder consumes, so
        # stray non-image files must not make an incomplete run look finished.
        n_png = sum(1 for f in os.listdir(img_dir) if f.endswith(".png"))
        if n_png >= min_images:
            print(f"[skip-sample] {out} exists", flush=True)
            return
    env = dict(os.environ, CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES=str(gpu),
               TORCHDYNAMO_DISABLE="1", PYTHONPATH=".", OMP_NUM_THREADS="4")
    subprocess.run([PY, "-m", "framework.synth.pixdiff.sample", "--ckpt", ckpt, "--data_root", DR,
                    "--dataset", ds, "--protocol", proto, "--train_fraction", str(frac),
                    "--fraction_seed", "0", "--n_per_mask", "1", "--num_steps", str(num_steps),
                    "--out_dir", out],
                   env=env, cwd=ROOT, check=True)
    print(f"[sampled] {out}", flush=True)

def fmap(d):
    """Index the PNGs in ``d/images`` by case stem.

    A file's stem is its name with the ``.png`` suffix removed and everything from
    the first ``"__"`` onward dropped. Files are visited in sorted order, and the
    first path seen for a stem is kept. When ``d/images`` is not a directory, an
    empty dict is returned.
    """
    img_dir = os.path.join(d, "images")
    if not os.path.isdir(img_dir):
        return {}
    index = {}
    for name in sorted(os.listdir(img_dir)):
        if not name.endswith(".png"):
            continue
        stem = name[:-4].partition("__")[0]
        if stem not in index:
            index[stem] = os.path.join(img_dir, name)
    return index
def _load_resized(p, mode, size=(256, 256)):
    """Open image ``p``, convert it to PIL ``mode``, resize it to ``size`` and return a NumPy array.

    The file is opened in a ``with`` block so its handle is released deterministically.
    Resizing uses Pillow's default resampling filter (same as the original one-liners).
    """
    with Image.open(p) as im:
        return np.asarray(im.convert(mode).resize(size))


def rgb(p, size=(256, 256)):
    """Load ``p`` as an RGB array resized to ``size`` (default 256x256)."""
    return _load_resized(p, "RGB", size)


def gray(p, size=(256, 256)):
    """Load ``p`` as a single-channel ("L") array resized to ``size`` (default 256x256)."""
    return _load_resized(p, "L", size)

def _pick_cases(stems, max_cols=6):
    """Return up to ``max_cols`` evenly spaced entries of the sorted list ``stems``.

    With two or more columns, the first and last stems are always included. A single
    stem yields that stem, and an empty list yields an empty selection.
    """
    ncol = min(max_cols, len(stems))
    if ncol == 0:
        return []
    if ncol == 1:
        return [stems[0]]
    return [stems[round(i * (len(stems) - 1) / (ncol - 1))] for i in range(ncol)]


def _draw_cell(a, src, bs, ri, rm):
    """Draw case ``bs`` into axes ``a`` for one grid row.

    ``src`` is one of:
      - ``"mask"``: the conditioning mask from ``rm``, drawn in grayscale;
      - ``"real"``: the training image from ``ri``;
      - a stem -> path dict from ``fmap``: a synthetic image.
    Every non-mask cell gets the mask outline (mask > 127) drawn on top. A load or
    draw failure is reported and replaced by a white "n/a" placeholder, so one bad
    case does not abort the whole figure.
    """
    try:
        mk = gray(f"{rm}/{bs}.png")
        if src == "mask":
            a.imshow(mk, cmap="gray")
        else:
            img = rgb(f"{ri}/{bs}.png") if src == "real" else rgb(src[bs])
            a.imshow(img)
            a.contour((mk > 127).astype(float), levels=[0.5], colors=["#19f04b"], linewidths=0.9)
    except Exception as e:  # deliberate best-effort: keep the figure, mark the cell
        print(f"[warn] {bs}: {e}", flush=True)
        a.imshow(np.ones((256, 256, 3)))
        a.text(0.5, 0.5, "n/a", ha="center", va="center", transform=a.transAxes)
    a.set_xticks([])
    a.set_yticks([])
    for s in a.spines.values():
        s.set_visible(False)


for dk, (ds, proto, tot) in DSETS.items():
    base = f"{DR}/{ds}/{proto}"
    # Fraction selecting the 50-mask ("f50") subset used by the native JiT align-set.
    f50 = 50 / tot
    fd_out = f"{base}/synth_alignfd_jitfd_{dk}"
    sample(f"pretrained/pixdiff/p1_jitfd_{dk}.pt", ds, proto, f50, fd_out)
    ri, rm = f"{base}/train/images", f"{base}/train/masks"
    nat = fmap(f"{base}/synth_align_jit_{dk}")
    fd = fmap(fd_out)
    real = {os.path.splitext(f)[0] for f in os.listdir(ri) if f.endswith(".png")}
    # Only cases that exist as a real image AND in both synth sets can be compared.
    cases = _pick_cases(sorted(real & set(nat) & set(fd)))
    if not cases:
        print(f"[skip-grid] {dk}: no case present in real, native and FD sets", flush=True)
        continue
    ncol = len(cases)
    # NOTE(review): matplotlib's default font (DejaVu Sans) has no CJK glyphs; the
    # Chinese text below may render as boxes unless a CJK-capable font is configured.
    rows = [("Conditioning mask", "mask"), ("Real", "real"), ("JiT (native, P1)", nat), ("JiT-FD (FD-感知)", fd)]
    # squeeze=False keeps `ax` 2-D even when only one case column is drawn.
    fig, ax = plt.subplots(len(rows), ncol, figsize=(ncol * 2.1, len(rows) * 2.15), squeeze=False)
    for r, (lab, src) in enumerate(rows):
        for c, bs in enumerate(cases):
            a = ax[r][c]
            _draw_cell(a, src, bs, ri, rm)
            if c == 0:
                a.set_ylabel(lab, fontsize=11, rotation=90, va="center", labelpad=8,
                             color=("#111" if r < 2 else "#1a3b8b"), fontweight=("bold" if r == 3 else "normal"))
    fig.suptitle(f"{dk.upper()} — 原生 JiT vs JiT-FD(同掩码):FD-感知精修是否更锐?", fontsize=12)
    fig.tight_layout(rect=[0.03, 0, 1, 0.95])
    out = f"/tmp/jit_vs_fd_{dk}.png"
    fig.savefig(out, dpi=150, bbox_inches="tight", facecolor="white")
    plt.close(fig)  # release the figure before building the next dataset's grid
    print(f"[grid] {out}", flush=True)
print("JIT_VS_FD_DONE", flush=True)