My SciPy ODE Solver Was Killing My Bayesian Inference: A Cosmologist’s Honest Account of Discovering Diffrax
Towards Data Science3191 字 (约 13 分钟)
87
Diffrax 用 JAX 实现的 ODE 求解器在 Cosmology 里将每次求解时间从 0.4 ms 降到 0.02 ms,梯度计算从 8 ms 降到 0.25 ms,整体提升 10‑倍以上,显著加速 Bayesian 推断。
入选理由:在 10⁵ 次 likelihood 评估中,SciPy ODE 仅 ODE 调用耗时 40 s,梯度 300 s;Diffrax 仅 24.8 s。
精选文章#Diffrax#JAX#ODE Solver#Bayesian Inference#Cosmology中文
