Add debugger precision analysis hooks #256
Code Review
This pull request introduces a pluggable analysis hook system to the MojoDebugger, enabling customizable precision metrics during model debugging. It adds a suite of metric functions—including SNR, KL/JS divergence, Angular Error, and Inner Product Distortion—alongside a default hook that automatically categorizes operators to report relevant metrics. The review feedback suggests improving the numerical stability of KL/JS calculations by using torch.nn.functional.kl_div, handling edge cases in SNR calculations when signal power is zero, and enabling programmatic retrieval of hook results.
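As a rough illustration of the kl_div suggestion, the sketch below computes KL and JS divergence entirely in log space. It assumes the compared tensors are logits over the last dimension, which the PR does not state; the function name kl_js_from_logits is hypothetical and is not part of MojoDebugger.

```python
import math

import torch
import torch.nn.functional as F


def kl_js_from_logits(ref_logits: torch.Tensor, out_logits: torch.Tensor):
    """Return (KL(ref || out), JS(ref, out)) averaged over rows, in nats.

    Hypothetical helper: assumes both inputs are logits over the last dim.
    """
    # log_softmax avoids the overflow/underflow of softmax followed by log.
    log_p = F.log_softmax(ref_logits.float(), dim=-1)
    log_q = F.log_softmax(out_logits.float(), dim=-1)
    n_rows = max(log_p.numel() // log_p.shape[-1], 1)

    # F.kl_div(input, target) computes KL(target || input), with input given
    # as log-probabilities; log_target=True keeps the target in log space too.
    kl = F.kl_div(log_q, log_p, reduction="sum", log_target=True) / n_rows

    # Mixture m = (p + q) / 2, built in log space via logsumexp.
    log_m = torch.logsumexp(torch.stack([log_p, log_q]), dim=0) - math.log(2.0)
    js = 0.5 * (
        F.kl_div(log_m, log_p, reduction="sum", log_target=True)
        + F.kl_div(log_m, log_q, reduction="sum", log_target=True)
    ) / n_rows
    return kl.item(), js.item()
```

JS divergence computed this way is bounded by ln 2, which makes it easier to threshold than raw KL.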
    signal_power = (ref_flat ** 2).sum().item()
    noise_power = (diff_flat ** 2).sum().item()
    snr_db = 10 * math.log10(max(signal_power / max(noise_power, 1e-30), 1e-30))
If signal_power is zero, the clamps make this expression return exactly -300 dB, which could be misread as a real measurement. The current logic is safe against division by zero, but it might be clearer to return float('-inf') or a dedicated constant when the signal is effectively zero, and to handle explicitly the case where both signal and noise are zero (where the output matches the reference exactly yet is still reported as -300 dB).
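One possible way to make these edge cases explicit is sketched below. The helper name safe_snr_db and the choice to report +inf when both powers are zero are assumptions for illustration, not the PR's behavior.

```python
import math


def safe_snr_db(signal_power: float, noise_power: float) -> float:
    """Hypothetical SNR helper with explicit zero-power handling."""
    if noise_power <= 0.0:
        # No error at all (this includes the all-zero reference case):
        # treated here as a perfect match. This convention is an assumption.
        return float("inf")
    if signal_power <= 0.0:
        # Zero reference signal but non-zero error: SNR is undefined in dB,
        # so report -inf rather than an arbitrary clamp value.
        return float("-inf")
    return 10.0 * math.log10(signal_power / noise_power)


print(safe_snr_db(100.0, 1.0))
print(safe_snr_db(0.0, 1.0))
print(safe_snr_db(0.0, 0.0))
```

Callers that format the result should check math.isinf first, since infinities make the special cases visible instead of hiding them behind a clamp constant.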
    for analysis_hook in self._analysis_hooks:
        try:
            analysis_hook(context)
        except Exception as e:
            logger.warning_once(
                f"{_PREFIX} Analysis hook failed for {tag}: {e}. "
                f"Analysis skipped, inference unaffected."
            )
The return values of the analysis hooks are currently ignored. While the default hook logs to the console, users providing custom hooks might expect a way to collect these metrics programmatically through the debugger instance. Consider storing the results of the hooks in a list or dictionary within MojoDebugger for later retrieval.
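A minimal sketch of what collecting hook results could look like follows. HookRunner is a hypothetical stand-in for the relevant part of MojoDebugger, and the method names add_hook, run_hooks, and get_results are assumptions; the standard logging module is used in place of the project's logger.

```python
import logging
from typing import Any, Callable, Dict, List, Optional

logger = logging.getLogger(__name__)


class HookRunner:
    """Hypothetical stand-in for the hook-running part of a debugger."""

    def __init__(self) -> None:
        self._analysis_hooks: List[Callable[[dict], Any]] = []
        # Results keyed by operator tag, in the order the hooks ran.
        self._analysis_results: Dict[str, List[Any]] = {}

    def add_hook(self, hook: Callable[[dict], Any]) -> None:
        self._analysis_hooks.append(hook)

    def run_hooks(self, tag: str, context: dict) -> None:
        for hook in self._analysis_hooks:
            try:
                result = hook(context)
            except Exception as e:
                # A failing hook is logged and skipped; other hooks still run.
                logger.warning("Analysis hook failed for %s: %s", tag, e)
                continue
            if result is not None:
                self._analysis_results.setdefault(tag, []).append(result)

    def get_results(self, tag: Optional[str] = None) -> Any:
        """Return results for one tag, or a copy of all results."""
        if tag is None:
            return {k: list(v) for k, v in self._analysis_results.items()}
        return list(self._analysis_results.get(tag, []))
```

Hooks that only log can keep returning None, so this kind of change would stay backward compatible with the default hook.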
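Adds precision-monitoring metrics to the compare flow of MojoDebugger and exposes them through an analysis_func hook. compare keeps its existing logic of "current backend output vs. torch reference output", then reports class A/B/C/D metrics, including KL/JS/TopK, angular error, IPD, SNR, RelL2, maximum error, and cosine similarity.
Also updates debug_suite.md and adds tests for the metric functions and for custom analysis hooks. Verified on h20 with ruff, py_compile, and the debugger-hook-related pytest subset.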
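For readers unfamiliar with some of the listed metrics, the sketch below computes four of them on flattened values in plain Python. The formulas are the common textbook definitions and are assumed rather than taken from the PR; the function name basic_precision_metrics and the eps guard are hypothetical.

```python
import math
from typing import Dict, Sequence


def basic_precision_metrics(
    out: Sequence[float], ref: Sequence[float], eps: float = 1e-30
) -> Dict[str, float]:
    """Hypothetical helper: backend output vs. reference, both flattened."""
    diff = [o - r for o, r in zip(out, ref)]
    ref_norm = math.sqrt(sum(r * r for r in ref))
    out_norm = math.sqrt(sum(o * o for o in out))
    diff_norm = math.sqrt(sum(d * d for d in diff))
    dot = sum(o * r for o, r in zip(out, ref))

    # Clamp so rounding error cannot push the value outside acos's domain.
    cosine = max(-1.0, min(1.0, dot / max(out_norm * ref_norm, eps)))
    return {
        "max_abs_err": max((abs(d) for d in diff), default=0.0),
        "rel_l2": diff_norm / max(ref_norm, eps),  # ||out - ref|| / ||ref||
        "cosine": cosine,
        "angle_deg": math.degrees(math.acos(cosine)),
    }
```

Cosine similarity and angular error carry the same information; the angle is often easier to read when the cosine is very close to 1.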