## "When are Iterative Gaussian Processes Reliably Accurate?"¶

by Wesley Maddox, Sanyam Kapoor, Andrew Gordon Wilson

Beyond First-Order Methods in ML Workshop at ICML 2021: https://sites.google.com/view/optml-icml2021/accepted-papers?authuser=0

## Gaussian Process Intro¶

\begin{align} f \sim \mathcal{GP}(\mu_\theta, k_\theta(\mathbf{X}, \mathbf{X})) \end{align}\begin{align} y \sim \mathcal{N}(f, \sigma^2 I) \end{align}
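To make the generative model concrete, here is a minimal sketch of sampling $f \sim \mathcal{GP}(0, k)$ and then $y \sim \mathcal{N}(f, \sigma^2 I)$ on a grid, assuming an RBF kernel (the kernel choice, lengthscale, and noise level are illustrative, not from the paper):

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale=1.0, outputscale=1.0):
    # k(x, x') = s^2 * exp(-(x - x')^2 / (2 * l^2)) for 1-D inputs
    sq_dists = (x1[:, None] - x2[None, :]) ** 2
    return outputscale * np.exp(-0.5 * sq_dists / lengthscale**2)

rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 50)
K = rbf_kernel(x, x)

# Sample f ~ GP(0, k) at the grid points; the jitter keeps the Cholesky stable.
L = np.linalg.cholesky(K + 1e-8 * np.eye(len(x)))
f = L @ rng.standard_normal(len(x))

# Noisy observations y ~ N(f, sigma^2 I)
sigma = 0.1
y = f + sigma * rng.standard_normal(len(x))
```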

While GPs are extremely flexible models, the catch is that the kernel function induces an $n \times n$ matrix that must be inverted for GP training and prediction.

Predictive equations: \begin{align} p(f(\mathbf{X}_\star) \mid \mathbf{X}_\star,& \mathcal{D}, \theta) = \mathcal{N}(\mu(\mathbf{X}_\star), \Sigma(\mathbf{X}_\star)), \label{eq:gp_post} \\ \mu(\mathbf{X}_\star) &= K_{\mathbf{X}_\star, \mathbf{X}} \widehat{K}_{\mathbf{X},\mathbf{X}}^{-1} \mathbf{y}~, \nonumber \\ \Sigma(\mathbf{X}_\star) &= K_{\mathbf{X}_\star, \mathbf{X}_\star} - K_{\mathbf{X}_\star, \mathbf{X}} \widehat{K}_{\mathbf{X},\mathbf{X}}^{-1} K_{\mathbf{X}_\star, \mathbf{X}}^\top ~, \nonumber \end{align} where $\widehat{K}_{\mathbf{X},\mathbf{X}} = K_{\mathbf{X},\mathbf{X}} + \sigma^2 \mathbf{I}$.
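As a reference point for the iterative methods below, the predictive equations can be evaluated exactly with a Cholesky factorization rather than an explicit inverse. This is a minimal sketch on synthetic 1-D data (the RBF kernel, data, and noise level are illustrative assumptions):

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale=1.0):
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / lengthscale**2)

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, 40)
y = np.sin(X) + 0.1 * rng.standard_normal(40)
X_star = np.linspace(-3, 3, 100)
sigma2 = 0.01

K_hat = rbf_kernel(X, X) + sigma2 * np.eye(len(X))  # K_XX + sigma^2 I
K_sX = rbf_kernel(X_star, X)                        # K_{X*, X}
K_ss = rbf_kernel(X_star, X_star)                   # K_{X*, X*}

# Solve against the Cholesky factor instead of forming the inverse explicitly.
L = np.linalg.cholesky(K_hat)
alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))  # K_hat^{-1} y
mu = K_sX @ alpha                                    # predictive mean
V = np.linalg.solve(L, K_sX.T)
Sigma = K_ss - V.T @ V                               # predictive covariance
```

The Cholesky route is the standard $\mathcal{O}(n^3)$ baseline that the iterative methods in this note replace.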

More specifically, training and prediction require solving $n \times n$ linear systems of the following form: \begin{align} \LARGE {\widehat{K}_{\mathbf{X},\mathbf{X}}} z = v \end{align} which naively takes $\mathcal{O}(n^3)$ time, generally limiting exact GP regression to no more than $\mathcal{O}(10{,}000)$ data points.
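Iterative GPs sidestep the cubic cost by solving these systems with conjugate gradients (CG), which only needs matrix-vector products with $\widehat{K}_{\mathbf{X},\mathbf{X}}$. A minimal, self-contained CG sketch (the kernel matrix here is an illustrative RBF example):

```python
import numpy as np

def conjugate_gradients(matvec, v, tol=1e-10, max_iter=1000):
    """Solve A z = v for symmetric positive-definite A, given only matvecs."""
    z = np.zeros_like(v)
    r = v - matvec(z)  # residual
    p = r.copy()       # search direction
    rs = r @ r
    for _ in range(max_iter):
        Ap = matvec(p)
        step = rs / (p @ Ap)
        z += step * p
        r -= step * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:  # stop on small residual norm
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return z

# Example: solve (K + sigma^2 I) z = v for a small RBF kernel matrix.
rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, 200)
K_hat = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2) + 0.1 * np.eye(200)
v = rng.standard_normal(200)
z = conjugate_gradients(lambda u: K_hat @ u, v)
```

Each iteration costs one $\mathcal{O}(n^2)$ matvec, so $t$ iterations cost $\mathcal{O}(t n^2)$ instead of $\mathcal{O}(n^3)$.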

## Iterative Gaussian Processes¶

Lanczos Decompositions

• Pleiss et al, ICML, 2018: https://arxiv.org/abs/1803.06058 \begin{align} \Sigma(\mathbf{X}_\star) &= K_{\mathbf{X}_\star, \mathbf{X}_\star} - K_{\mathbf{X}_\star, \mathbf{X}} \widehat{K}_{\mathbf{X},\mathbf{X}}^{-1} K_{\mathbf{X}_\star, \mathbf{X}}^\top \\ &\approx K_{\mathbf{X}_\star, \mathbf{X}_\star} - K_{\mathbf{X}_\star, \mathbf{X}} LL^\top K_{\mathbf{X}_\star, \mathbf{X}}^\top, \end{align} where $LL^\top \approx \widehat{K}_{\mathbf{X},\mathbf{X}}^{-1}$. The rank $r$ of the Lanczos decomposition largely determines the accuracy of the approximation.
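A rank-$r$ Lanczos run produces an orthonormal basis $Q$ and a tridiagonal $T$ with $Q^\top \widehat{K} Q = T$, giving $\widehat{K}^{-1} \approx Q T^{-1} Q^\top$ and hence $L = Q T^{-1/2}$. The sketch below is a generic reorthogonalized Lanczos iteration, not the paper's implementation; the kernel matrix and rank are illustrative:

```python
import numpy as np

def lanczos(matvec, b, rank):
    """Rank-r Lanczos: orthonormal Q (n x r) and tridiagonal T with Q^T A Q = T."""
    n = b.shape[0]
    Q = np.zeros((n, rank))
    alphas, betas = [], []
    q, q_prev, beta = b / np.linalg.norm(b), np.zeros(n), 0.0
    for j in range(rank):
        Q[:, j] = q
        w = matvec(q) - beta * q_prev
        alpha = q @ w
        w -= alpha * q
        # Full reorthogonalization against all previous basis vectors.
        w -= Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)
        alphas.append(alpha)
        beta = np.linalg.norm(w)
        if beta < 1e-12:  # Krylov space exhausted; stop early
            break
        betas.append(beta)
        q_prev, q = q, w / beta
    r = len(alphas)
    T = np.diag(alphas) + np.diag(betas[: r - 1], 1) + np.diag(betas[: r - 1], -1)
    return Q[:, :r], T

rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, 300)
K_hat = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2) + 0.1 * np.eye(300)
probe = rng.standard_normal(300)
Q, T = lanczos(lambda u: K_hat @ u, probe, rank=50)

# L L^T = Q T^{-1} Q^T approximates K_hat^{-1}; take L = Q T^{-1/2}
# via the eigendecomposition of the small tridiagonal T.
evals, evecs = np.linalg.eigh(T)
L = Q @ (evecs / np.sqrt(evals))
```

Since $T$ is only $r \times r$, the eigendecomposition is cheap; the dominant cost is the $r$ kernel matvecs.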

### Investigating the Predictive Variances¶

While the MSE is quite small, the NLL is much worse than the MSE would naively suggest. Indeed, inspecting the NLL expression, we can immediately spot what's going on:

\begin{align} \LARGE \mathrm{NLL}(\sigma^2) := \frac{1}{2}\log \sigma^2 + \frac{1}{2\sigma^2}(\mu - y)^2~, \end{align}
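The quadratic term in this expression is divided by $\sigma^2$, so an underestimated predictive variance can blow up the NLL even when the squared error $(\mu - y)^2$ is tiny. A quick numerical illustration (the numbers are made up for illustration, not taken from the paper's experiments):

```python
import numpy as np

def gaussian_nll(mu, sigma2, y):
    # Pointwise Gaussian NLL, with the constant log(2*pi)/2 omitted as above.
    return 0.5 * np.log(sigma2) + 0.5 * (mu - y) ** 2 / sigma2

# Identical squared error in both cases, but an underestimated variance
# makes the model overconfident and inflates the NLL.
mu, y = 0.0, 0.1                             # squared error of 0.01
well_calibrated = gaussian_nll(mu, 0.01, y)  # sigma^2 matches the error scale
overconfident = gaussian_nll(mu, 1e-6, y)    # sigma^2 badly underestimated
```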

### Investigating the Predictive Means¶

Finally, we close by demonstrating the effect of the conjugate gradients tolerance threshold on the predictive means.
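A minimal sketch of this effect, assuming a plain CG solver and an illustrative RBF kernel (not the paper's experimental setup): a looser residual tolerance stops CG earlier, and the resulting error in $\widehat{K}^{-1}\mathbf{y}$ propagates directly into the predictive mean $K_{\mathbf{X}_\star,\mathbf{X}} \widehat{K}^{-1}\mathbf{y}$.

```python
import numpy as np

def cg(A, v, tol, max_iter=2000):
    # Plain conjugate gradients with a residual-norm stopping criterion.
    z = np.zeros_like(v)
    r, p = v.copy(), v.copy()
    rs = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs) <= tol:
            break
        Ap = A @ p
        step = rs / (p @ Ap)
        z += step * p
        r -= step * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return z

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, 200)
y = np.sin(X) + 0.1 * rng.standard_normal(200)
X_star = np.linspace(-3, 3, 50)
K_hat = np.exp(-0.5 * (X[:, None] - X[None, :]) ** 2) + 0.01 * np.eye(200)
K_sX = np.exp(-0.5 * (X_star[:, None] - X[None, :]) ** 2)

exact_mu = K_sX @ np.linalg.solve(K_hat, y)
# Max predictive-mean error as a function of the CG tolerance threshold.
errors = {tol: np.max(np.abs(K_sX @ cg(K_hat, y, tol) - exact_mu))
          for tol in (1e-1, 1e-4, 1e-8)}
```

Tightening the tolerance drives the predictive means toward the exact Cholesky answer, at the cost of more CG iterations.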