File size: 3,187 Bytes
34d63ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/env python3
"""Plot debug_pass{N}.npz files saved by predict.py when MIMI_DEBUG=1."""
import sys
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Directory that contains this script; resolved so symlinked invocations work.
_HERE = Path(__file__).resolve().parent

# class index → label (from training config)
CLASS_NAMES = ["bos", "system_end", "user_end", "system", "user"]
# One line colour per class, listed in the same order as CLASS_NAMES.
CLASS_COLORS = ["#9E9E9E", "#FF9800", "#8BC34A", "#03A9F4", "#2196F3"]
# Column index of the "user" class.  Derived from CLASS_NAMES instead of being
# hard-coded so the index cannot silently drift if the label list is edited.
IDX_USER = CLASS_NAMES.index("user")


def plot_npz(path: Path) -> None:
    """Render one debug dump as a four-panel PNG written next to the archive.

    The archive is read for the keys ``subj``, ``other``, ``floor``,
    ``threshold``, ``sr`` and ``frame_rate``, plus either ``probs`` (one
    column per class) or the legacy ``p_user`` (probability of the ``user``
    class only; the remaining columns are filled with zeros).

    Panels, top to bottom: the ``subj`` waveform, the ``other`` waveform,
    the per-class probabilities with the threshold line, and the floor bit.

    Args:
        path: Location of the ``.npz`` archive.  The figure is saved to the
            same path with the suffix replaced by ``.png``.
    """
    # NpzFile holds the underlying file open until it is closed, so pull
    # every array we need while inside the ``with`` block.
    with np.load(path) as d:
        subj = d["subj"]
        other = d["other"]
        floor = d["floor"].astype(np.float32)
        if "probs" in d.files:
            probs = d["probs"]      # new format: one column per class
        else:
            p_user = d["p_user"]    # old format: only P(user) was stored
            probs = np.zeros((len(p_user), len(CLASS_NAMES)), dtype=np.float32)
            probs[:, IDX_USER] = p_user
        thr = float(d["threshold"])
        sr = int(d["sr"])
        hz = int(d["frame_rate"])

    n_steps = len(floor)
    duration = n_steps / hz
    # Separate time axes: waveforms are sampled at ``sr``, frames at ``hz``.
    t_wav = np.arange(len(subj)) / sr
    t_frame = np.arange(n_steps) / hz

    out = path.with_suffix(".png")

    # Width grows with the recording length but never drops below 20 inches.
    fig, axes = plt.subplots(4, 1, figsize=(max(20, duration / 5), 10),
                              gridspec_kw={"hspace": 0.05, "height_ratios": [1, 1, 2, 1]})
    try:
        # waveforms
        for ax, wav, label in zip(axes[:2], [subj, other], ["subject", "other"]):
            ax.plot(t_wav, wav, lw=0.3, color="#555", alpha=0.8)
            ax.set_ylabel(label, fontsize=8)
            ax.set_xlim(0, duration)
            ax.tick_params(labelbottom=False)

        # all class probabilities; the "user" curve is drawn heavier and shaded
        ax_prob = axes[2]
        for idx, (name, color) in enumerate(zip(CLASS_NAMES, CLASS_COLORS)):
            lw = 1.2 if idx == IDX_USER else 0.7
            alpha = 0.9 if idx == IDX_USER else 0.7
            ax_prob.plot(t_frame, probs[:, idx], lw=lw, color=color, alpha=alpha, label=name)
        ax_prob.fill_between(t_frame, probs[:, IDX_USER], alpha=0.1, color=CLASS_COLORS[IDX_USER])
        ax_prob.axhline(thr, color="red", lw=0.8, ls="--", label=f"threshold={thr}")
        ax_prob.set_ylim(-0.05, 1.05)
        ax_prob.set_ylabel("class probs", fontsize=8)
        ax_prob.legend(fontsize=7, loc="upper right", ncol=3)
        ax_prob.set_xlim(0, duration)
        ax_prob.tick_params(labelbottom=False)

        # floor bit
        ax_floor = axes[3]
        ax_floor.step(t_frame, floor, where="post", lw=0.8, color="#4CAF50")
        ax_floor.fill_between(t_frame, floor, step="post", alpha=0.2, color="#4CAF50")
        ax_floor.set_ylim(-0.1, 1.1)
        ax_floor.set_ylabel("floor bit", fontsize=8)
        ax_floor.set_xlabel("Time (s)", fontsize=9)
        ax_floor.set_xlim(0, duration)

        # faint vertical guide at every whole second, on every panel
        for ax in axes:
            for vl in np.arange(1, int(duration) + 1):
                ax.axvline(vl, color="#ddd", lw=0.4, zorder=0)

        axes[0].set_title(f"mimi_endpointer  {path.stem}  threshold={thr}", fontsize=9, loc="left", pad=3)
        fig.savefig(out, dpi=100, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed;
        # otherwise pyplot keeps it alive for the rest of the run.
        plt.close(fig)
    print(f"saved → {out}")


if __name__ == "__main__":
    # Gather every debug dump that sits beside this script, in sorted order.
    dumps = sorted(_HERE.glob("debug_pass*.npz"))
    if dumps:
        for dump in dumps:
            plot_npz(dump)
    else:
        # Nothing to plot: tell the user how to produce dumps and exit non-zero.
        print("No debug_pass*.npz files found. Run with MIMI_DEBUG=1 first.")
        sys.exit(1)