#!/usr/bin/env python3
"""Plot debug_pass{N}.npz files saved by predict.py when MIMI_DEBUG=1.

Run as a script, every ``debug_pass*.npz`` file that sits next to this module
is rendered to a PNG with the same stem.
"""
import sys
from pathlib import Path

import matplotlib

# Select the non-interactive backend before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

_HERE = Path(__file__).resolve().parent

# class index → label (from training config)
CLASS_NAMES = ["bos", "system_end", "user_end", "system", "user"]
CLASS_COLORS = ["#9E9E9E", "#FF9800", "#8BC34A", "#03A9F4", "#2196F3"]
IDX_USER = 4


def plot_npz(path: Path) -> None:
    """Render one debug ``.npz`` dump as a four-panel PNG next to it.

    The panels are, top to bottom: the ``subj`` waveform, the ``other``
    waveform, the per-class probabilities with the decision threshold, and
    the ``floor`` signal. The image is written to ``path`` with its suffix
    replaced by ``.png`` and the output location is printed.

    Args:
        path: Location of the ``.npz`` archive. It must contain ``subj``,
            ``other``, ``floor``, ``threshold``, ``sr`` and ``frame_rate``,
            plus either ``probs`` (new format) or ``p_user`` (old format).

    Raises:
        KeyError: If a required array is missing from the archive.
    """
    # NpzFile keeps the underlying file open; the context manager guarantees
    # the handle is released once every array has been read into memory.
    with np.load(path) as d:
        subj = d["subj"]
        other = d["other"]
        floor = d["floor"].astype(np.float32)
        if "probs" in d.files:
            probs = d["probs"]  # (T, 5) — new format
        else:
            # old format: only P(user) was stored; other classes stay at 0
            p_user = d["p_user"]
            probs = np.zeros((len(p_user), len(CLASS_NAMES)), dtype=np.float32)
            probs[:, IDX_USER] = p_user
        thr = float(d["threshold"])
        sr = int(d["sr"])
        hz = int(d["frame_rate"])

    n_steps = len(floor)
    duration = n_steps / hz
    t_wav = np.arange(len(subj)) / sr
    t_frame = np.arange(n_steps) / hz

    fig, axes = plt.subplots(
        4,
        1,
        figsize=(max(20, duration / 5), 10),
        gridspec_kw={"hspace": 0.05, "height_ratios": [1, 1, 2, 1]},
    )
    try:
        # waveforms
        for ax, wav, label in zip(axes[:2], [subj, other], ["subject", "other"]):
            ax.plot(t_wav, wav, lw=0.3, color="#555", alpha=0.8)
            ax.set_ylabel(label, fontsize=8)
            ax.set_xlim(0, duration)
            ax.tick_params(labelbottom=False)

        # all class probabilities; the "user" curve is drawn more prominently
        ax_prob = axes[2]
        for idx, (name, color) in enumerate(zip(CLASS_NAMES, CLASS_COLORS)):
            is_user = idx == IDX_USER
            ax_prob.plot(
                t_frame,
                probs[:, idx],
                lw=1.2 if is_user else 0.7,
                color=color,
                alpha=0.9 if is_user else 0.7,
                label=name,
            )
        ax_prob.fill_between(
            t_frame, probs[:, IDX_USER], alpha=0.1, color=CLASS_COLORS[IDX_USER]
        )
        ax_prob.axhline(thr, color="red", lw=0.8, ls="--", label=f"threshold={thr}")
        ax_prob.set_ylim(-0.05, 1.05)
        ax_prob.set_ylabel("class probs", fontsize=8)
        ax_prob.legend(fontsize=7, loc="upper right", ncol=3)
        ax_prob.set_xlim(0, duration)
        ax_prob.tick_params(labelbottom=False)

        # floor bit
        ax_floor = axes[3]
        ax_floor.step(t_frame, floor, where="post", lw=0.8, color="#4CAF50")
        ax_floor.fill_between(t_frame, floor, step="post", alpha=0.2, color="#4CAF50")
        ax_floor.set_ylim(-0.1, 1.1)
        ax_floor.set_ylabel("floor bit", fontsize=8)
        ax_floor.set_xlabel("Time (s)", fontsize=9)
        ax_floor.set_xlim(0, duration)

        # faint vertical guide at every whole second, on every panel
        second_marks = np.arange(1, int(duration) + 1)
        for ax in axes:
            for vl in second_marks:
                ax.axvline(vl, color="#ddd", lw=0.4, zorder=0)

        axes[0].set_title(
            f"mimi_endpointer {path.stem} threshold={thr}",
            fontsize=9,
            loc="left",
            pad=3,
        )

        out = path.with_suffix(".png")
        fig.savefig(out, dpi=100, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed, so a
        # bad file does not leave figures accumulating in pyplot's registry.
        plt.close(fig)
    print(f"saved → {out}")


if __name__ == "__main__":
    npz_files = sorted(_HERE.glob("debug_pass*.npz"))
    if not npz_files:
        print("No debug_pass*.npz files found. Run with MIMI_DEBUG=1 first.")
        sys.exit(1)
    for f in npz_files:
        plot_npz(f)