mimi-endpointer / plot_debug.py
viks66's picture
add plot_debug.py
34d63ae verified
#!/usr/bin/env python3
"""Plot debug_pass{N}.npz files saved by predict.py when MIMI_DEBUG=1."""
import sys
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Directory that contains this script.
_HERE = Path(__file__).resolve().parent

# Class labels in index order, i.e. list position == class index
# (from training config).
CLASS_NAMES = [
    "bos",
    "system_end",
    "user_end",
    "system",
    "user",
]
# One line colour per class, in the same order as CLASS_NAMES.
CLASS_COLORS = [
    "#9E9E9E",
    "#FF9800",
    "#8BC34A",
    "#03A9F4",
    "#2196F3",
]
# Index of the "user" class within CLASS_NAMES.
IDX_USER = CLASS_NAMES.index("user")
def plot_npz(path: Path) -> None:
    """Render one debug dump as a four-panel PNG saved next to the ``.npz``.

    Panels, top to bottom: the ``subj`` waveform, the ``other`` waveform, all
    class-probability traces together with the threshold line, and the
    ``floor`` bit as a step plot. Every panel spans 0..duration seconds, where
    duration is ``len(floor) / frame_rate``.

    The archive must provide ``subj``, ``other``, ``floor``, ``threshold``,
    ``sr`` and ``frame_rate``, plus either ``probs`` (one column per entry of
    ``CLASS_NAMES``) or the older ``p_user`` array (only P(user) stored).

    Args:
        path: Location of the ``.npz`` file. The image is written to the same
            path with the suffix replaced by ``.png``.
    """
    # np.load returns an NpzFile that keeps the archive open until it is
    # closed; pull out everything we need inside a ``with`` block so a run
    # over many dumps does not leak one file handle per dump.
    with np.load(path) as d:
        subj = d["subj"]
        other = d["other"]
        floor = d["floor"].astype(np.float32)
        if "probs" in d.files:
            probs = d["probs"]  # (T, 5) — new format
        else:
            # Old format: only P(user) was stored; other classes stay at 0.
            p_user = d["p_user"]
            probs = np.zeros((len(p_user), len(CLASS_NAMES)), dtype=np.float32)
            probs[:, IDX_USER] = p_user
        thr = float(d["threshold"])
        sr = int(d["sr"])
        hz = int(d["frame_rate"])

    n_steps = len(floor)
    duration = n_steps / hz
    t_wav = np.arange(len(subj)) / sr  # one timestamp per audio sample
    t_frame = np.arange(n_steps) / hz  # one timestamp per model frame

    # Width grows with the recording length but never drops below 20 inches.
    fig, axes = plt.subplots(
        4, 1,
        figsize=(max(20, duration / 5), 10),
        gridspec_kw={"hspace": 0.05, "height_ratios": [1, 1, 2, 1]},
    )
    # Close the figure even when plotting or saving fails; otherwise pyplot
    # keeps it registered for the rest of the process.
    try:
        # waveforms
        for ax, wav, label in zip(axes[:2], [subj, other], ["subject", "other"]):
            ax.plot(t_wav, wav, lw=0.3, color="#555", alpha=0.8)
            ax.set_ylabel(label, fontsize=8)
            ax.set_xlim(0, duration)
            ax.tick_params(labelbottom=False)

        # all 5 class probabilities; the "user" trace is drawn heavier
        ax_prob = axes[2]
        for idx, (name, color) in enumerate(zip(CLASS_NAMES, CLASS_COLORS)):
            lw = 1.2 if idx == IDX_USER else 0.7
            alpha = 0.9 if idx == IDX_USER else 0.7
            ax_prob.plot(t_frame, probs[:, idx], lw=lw, color=color,
                         alpha=alpha, label=name)
        ax_prob.fill_between(t_frame, probs[:, IDX_USER], alpha=0.1,
                             color=CLASS_COLORS[IDX_USER])
        ax_prob.axhline(thr, color="red", lw=0.8, ls="--",
                        label=f"threshold={thr}")
        ax_prob.set_ylim(-0.05, 1.05)
        ax_prob.set_ylabel("class probs", fontsize=8)
        ax_prob.legend(fontsize=7, loc="upper right", ncol=3)
        ax_prob.set_xlim(0, duration)
        ax_prob.tick_params(labelbottom=False)

        # floor bit
        axes[3].step(t_frame, floor, where="post", lw=0.8, color="#4CAF50")
        axes[3].fill_between(t_frame, floor, step="post", alpha=0.2,
                             color="#4CAF50")
        axes[3].set_ylim(-0.1, 1.1)
        axes[3].set_ylabel("floor bit", fontsize=8)
        axes[3].set_xlabel("Time (s)", fontsize=9)
        axes[3].set_xlim(0, duration)

        # faint vertical guide at every whole second, on every panel
        for ax in axes:
            for vl in np.arange(1, int(duration) + 1):
                ax.axvline(vl, color="#ddd", lw=0.4, zorder=0)

        axes[0].set_title(f"mimi_endpointer {path.stem} threshold={thr}",
                          fontsize=9, loc="left", pad=3)
        out = path.with_suffix(".png")
        fig.savefig(out, dpi=100, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"saved → {out}")
if __name__ == "__main__":
    # Plot every debug dump found in this script's directory, in sorted
    # filename order; exit with status 1 when there is nothing to plot.
    dumps = sorted(_HERE.glob("debug_pass*.npz"))
    if dumps:
        for dump in dumps:
            plot_npz(dump)
    else:
        print("No debug_pass*.npz files found. Run with MIMI_DEBUG=1 first.")
        sys.exit(1)