Chapter 15: Variational Inference

When the posterior \(P(\mathbf{z}|\mathbf{x})\) is intractable — as in almost all deep latent variable models — we replace exact inference with an optimisation problem: find the member of a tractable family \(q(\mathbf{z})\) closest to the true posterior in KL divergence.

1. The Intractability Problem

In a latent variable model with observations \(\mathbf{x}\) and latent variables \(\mathbf{z}\), we want the posterior:

\[ P(\mathbf{z} \mid \mathbf{x}) = \frac{P(\mathbf{x} \mid \mathbf{z})\,P(\mathbf{z})}{P(\mathbf{x})}, \qquad P(\mathbf{x}) = \int P(\mathbf{x} \mid \mathbf{z})\,P(\mathbf{z})\,d\mathbf{z} \]

The marginal \(P(\mathbf{x})\) requires integrating over all configurations of \(\mathbf{z}\): a sum that grows exponentially with dimension for discrete latents, and an integral with no closed form in general for continuous ones. Variational inference turns this intractable integration into optimisation.

2. The ELBO: Full Derivation

Introduce a variational distribution \(q(\mathbf{z})\). Multiply and divide the integrand by \(q(\mathbf{z})\) so the integral becomes an expectation under \(q\):

\[ \log P(\mathbf{x}) = \log \int P(\mathbf{x}, \mathbf{z})\,d\mathbf{z} = \log \int q(\mathbf{z})\frac{P(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})}\,d\mathbf{z} \]

Apply Jensen's inequality (\(\log\) is concave, so \(\log \mathbb{E}[X] \geq \mathbb{E}[\log X]\)):

\[ \log P(\mathbf{x}) \;\geq\; \underbrace{\mathbb{E}_{q(\mathbf{z})}\!\left[\log \frac{P(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})}\right]}_{\mathcal{L}(q) \;=\; \text{ELBO}} \]

To see the gap, expand using the KL divergence:

\begin{align} \log P(\mathbf{x}) &= \mathbb{E}_{q}\!\left[\log \frac{P(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})}\right] + \mathbb{E}_{q}\!\left[\log \frac{q(\mathbf{z})}{P(\mathbf{z} \mid \mathbf{x})}\right] \\ &= \mathcal{L}(q) + \mathrm{KL}(q(\mathbf{z}) \| P(\mathbf{z} \mid \mathbf{x})) \end{align}

Since \(\mathrm{KL} \geq 0\) with equality iff \(q = P(\mathbf{z}|\mathbf{x})\), maximising the ELBO is equivalent to minimising the KL to the true posterior. Rewrite the ELBO:

\[ \mathcal{L}(q) = \underbrace{\mathbb{E}_{q(\mathbf{z})}\!\left[\log P(\mathbf{x} \mid \mathbf{z})\right]}_{\text{reconstruction}} - \underbrace{\mathrm{KL}(q(\mathbf{z}) \| P(\mathbf{z}))}_{\text{regularisation}} \]

This decomposition is the VAE objective from Chapter 12: the first term rewards fitting the data; the second penalises the variational distribution for deviating from the prior.
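The identity \(\log P(\mathbf{x}) = \mathcal{L}(q) + \mathrm{KL}(q \| P(\mathbf{z}|\mathbf{x}))\) can be checked numerically in a fully conjugate toy model where every term has a closed form. A sketch, assuming the illustrative model \(z \sim \mathcal{N}(0,1)\), \(x|z \sim \mathcal{N}(z,1)\), whose posterior is \(\mathcal{N}(x/2, 1/2)\) and whose evidence is \(\mathcal{N}(x; 0, 2)\):

```python
import numpy as np

x = 2.0                      # single observation
m, s2 = 0.3, 0.7             # arbitrary variational parameters: q(z) = N(m, s2)

# Model: z ~ N(0, 1), x | z ~ N(z, 1)  =>  posterior N(x/2, 1/2), evidence N(x; 0, 2)
log_px = -0.5 * np.log(2 * np.pi * 2.0) - x**2 / 4.0

# ELBO = E_q[log p(x|z)] - KL(q || prior), both available in closed form here
expected_loglik = -0.5 * np.log(2 * np.pi) - 0.5 * ((x - m)**2 + s2)
kl_q_prior = 0.5 * (s2 + m**2 - 1.0 - np.log(s2))
elbo = expected_loglik - kl_q_prior

# KL(q || posterior) for two univariate Gaussians
mp, sp2 = x / 2.0, 0.5
kl_q_post = 0.5 * (np.log(sp2 / s2) + (s2 + (m - mp)**2) / sp2 - 1.0)

print(elbo + kl_q_post, log_px)   # the two numbers coincide for any (m, s2)
```

Changing \(m\) or \(s_2\) moves the ELBO and the KL in opposite directions while their sum stays pinned at \(\log P(x)\), which is exactly why maximising one minimises the other.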

3. ELBO Decomposition Diagram

[Diagram: \(\log P(\mathbf{x}) = \mathcal{L}(q) + \mathrm{KL}(q \| p(\mathbf{z}|\mathbf{x}))\); the ELBO splits into a reconstruction term \(\mathbb{E}_q[\log p(\mathbf{x}|\mathbf{z})]\) minus a regularisation term \(\mathrm{KL}(q(\mathbf{z}) \| p(\mathbf{z}))\), so maximising the ELBO is equivalent to minimising the KL to the posterior.]

4. Mean-Field Approximation & Coordinate Ascent VI

The mean-field family assumes the latent variables factorise:

\[ q(\mathbf{z}) = \prod_{j=1}^{J} q_j(z_j) \]

To find the optimal \(q_j^*(z_j)\) holding all others fixed, take the functional derivative of the ELBO with respect to \(q_j\) and set it to zero. Because \(\log q(\mathbf{z}) = \sum_j \log q_j(z_j)\), the expectation splits across factors and the stationarity condition gives:

\[ \log q_j^*(z_j) = \mathbb{E}_{q_{-j}}\!\left[\log P(\mathbf{x}, \mathbf{z})\right] + \text{const} \]

This is the expectation of the log-joint over all variables except \(z_j\). The normalising constant is determined by requiring \(\int q_j^*(z_j)\,dz_j = 1\). Coordinate Ascent VI (CAVI) cycles through all \(j\), updating each \(q_j\) while holding others fixed — guaranteed to increase the ELBO at each step.

4.1 Application: Bayesian Linear Regression

With prior \(\mathbf{w} \sim \mathcal{N}(\mathbf{0}, \alpha^{-1}I)\) and likelihood \(\mathbf{y} \sim \mathcal{N}(\mathbf{X}\mathbf{w}, \beta^{-1}I)\), the log-joint is:

\[ \log P(\mathbf{y}, \mathbf{w}) = -\frac{\beta}{2}\|\mathbf{y} - \mathbf{X}\mathbf{w}\|^2 - \frac{\alpha}{2}\|\mathbf{w}\|^2 + \text{const} \]

This is quadratic in \(\mathbf{w}\), so the optimal \(q^*(\mathbf{w}) = \mathcal{N}(\mathbf{m}_N, \mathbf{S}_N)\) with:

\[ \mathbf{S}_N^{-1} = \alpha\mathbf{I} + \beta\mathbf{X}^\top\mathbf{X}, \qquad \mathbf{m}_N = \beta\mathbf{S}_N\mathbf{X}^\top\mathbf{y} \]

This matches the exact posterior — a special property of Gaussian models. For non-Gaussian models, VI provides an approximation.

5. Amortised Inference & VAEs

Classical VI optimises \(q(\mathbf{z})\) separately for each data point — \(O(N)\) separate optimisations. Amortised inference trains a neural network \(q_\phi(\mathbf{z}|\mathbf{x})\) (the encoder) to predict variational parameters directly from \(\mathbf{x}\):

\[ \mathcal{L}(\phi, \theta) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\!\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right] - \mathrm{KL}\!\left(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z})\right) \]

This is exactly the VAE objective from Chapter 12. The encoder \(q_\phi\) and decoder \(p_\theta\) are trained jointly by maximising the ELBO via the reparameterisation trick: sample \(\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\varepsilon}\), \(\boldsymbol{\varepsilon} \sim \mathcal{N}(\mathbf{0}, I)\), so that gradients propagate through the sampling step.
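A minimal NumPy sketch of the reparameterised gradient estimator, for the toy objective \(\mathbb{E}_{q}[z^2]\) with \(q = \mathcal{N}(\mu, \sigma^2)\) (chosen because its exact gradients, \(2\mu\) and \(2\sigma\), are known):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.5, 1.5
eps = rng.standard_normal(200_000)    # noise is sampled once, held fixed

# Reparameterise z = mu + sigma * eps; for f(z) = z^2 the chain rule gives
#   d f / d mu    = 2 z
#   d f / d sigma = 2 z * eps
z = mu + sigma * eps
grad_mu = np.mean(2 * z)              # Monte Carlo estimate of 2*mu    = 1.0
grad_sigma = np.mean(2 * z * eps)     # Monte Carlo estimate of 2*sigma = 3.0
print(grad_mu, grad_sigma)
```

The same pattern, with \(f\) replaced by the ELBO integrand and the derivative taken by autodiff, is what a VAE framework computes at every training step.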

5.1 Expectation Propagation (brief overview)

EP minimises \(\mathrm{KL}(P(\mathbf{z}|\mathbf{x}) \| q(\mathbf{z}))\) (locally, one factor at a time): the reverse of the VI objective. This inclusive KL pushes \(q\) to cover all regions where the posterior has mass, whereas VI's exclusive KL is mode-seeking and tends to collapse onto a single mode. EP is therefore often preferred when broad coverage of the posterior matters, though with a unimodal \(q\) the inclusive objective can place mass between well-separated modes.
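The two KL directions can be compared directly by fitting a single Gaussian to a bimodal target with a grid search. A sketch, with an illustrative equal mixture of \(\mathcal{N}(-2, 0.5^2)\) and \(\mathcal{N}(2, 0.5^2)\) as the "posterior":

```python
import numpy as np

zs = np.linspace(-8, 8, 2001)
dz = zs[1] - zs[0]

def normal(z, m, s):
    return np.exp(-0.5 * ((z - m) / s) ** 2) / (s * np.sqrt(2 * np.pi))

def kl(a, b):
    # grid estimate of KL(a || b); 0 log 0 is treated as 0, tiny b is floored
    mask = a > 0
    return np.sum(a[mask] * np.log(a[mask] / np.maximum(b[mask], 1e-300))) * dz

# Bimodal "posterior": equal mixture of N(-2, 0.5^2) and N(2, 0.5^2)
p = 0.5 * normal(zs, -2, 0.5) + 0.5 * normal(zs, 2, 0.5)

best_rev = (np.inf, 0.0, 0.0)
best_fwd = (np.inf, 0.0, 0.0)
for m in np.linspace(-3, 3, 61):
    for s in np.linspace(0.3, 3, 28):
        q = normal(zs, m, s)
        best_rev = min(best_rev, (kl(q, p), m, s))  # exclusive KL: VI
        best_fwd = min(best_fwd, (kl(p, q), m, s))  # inclusive KL: EP-style

print("reverse KL (VI) picks m=%.1f, s=%.1f" % best_rev[1:])  # one mode
print("forward KL (EP) picks m=%.1f, s=%.1f" % best_fwd[1:])  # spans both
```

The exclusive fit sits tightly on one mode; the inclusive fit centres between the modes with a large variance (moment matching), putting mass where the target has almost none.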

6. Python Simulation: VI vs Exact Posterior

We implement CAVI for Bayesian linear regression with a quadratic feature basis and compare the variational posterior against the exact Gaussian posterior. Treating \(\mathbf{w}\) as a single Gaussian block, the update recovers the exact posterior; a fully factorised (mean-field) \(q\) still recovers the exact posterior mean, though it underestimates the marginal variances when the features are correlated. Either way, the example illustrates that VI is exact when the variational family contains the true posterior.
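A compact sketch of the simulation, with illustrative data and precision values \(\alpha, \beta\). Coordinate-wise CAVI on the factorised \(q\) amounts to Gauss-Seidel sweeps on \(\mathbf{S}_N^{-1}\mathbf{m} = \beta\mathbf{X}^\top\mathbf{y}\), so the variational means converge to the exact posterior mean \(\mathbf{m}_N\):

```python
import numpy as np

rng = np.random.default_rng(1)
alpha, beta = 1.0, 25.0                       # prior / noise precisions

# Data from a quadratic, with features Phi = [1, x, x^2]
x = rng.uniform(-1, 1, 60)
y = 0.5 - x + 2 * x**2 + rng.normal(0, 0.2, x.size)
Phi = np.column_stack([np.ones_like(x), x, x**2])

# Exact posterior: S_N^{-1} = alpha I + beta Phi^T Phi, m_N = beta S_N Phi^T y
A = alpha * np.eye(3) + beta * Phi.T @ Phi
b = beta * Phi.T @ y
m_exact = np.linalg.solve(A, b)

# Mean-field CAVI: each coordinate update takes E_{-j}[log joint], a Gaussian
# with precision A[j,j]; the mean updates are Gauss-Seidel sweeps on A m = b
m = np.zeros(3)
for _ in range(200):
    for j in range(3):
        m[j] = (b[j] - A[j] @ m + A[j, j] * m[j]) / A[j, j]

print("CAVI mean: ", m)
print("exact mean:", m_exact)                 # the means coincide
print("CAVI vars:", 1 / np.diag(A), "vs exact:", np.diag(np.linalg.inv(A)))
```

The printed variances show the familiar mean-field shrinkage: \(1/A_{jj}\) is never larger than the exact marginal variance \((A^{-1})_{jj}\).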
