Chapter 15: Variational Inference
When the posterior \(P(\mathbf{z}|\mathbf{x})\) is intractable — as in almost all deep latent variable models — we replace exact inference with an optimisation problem: find the member of a tractable family \(q(\mathbf{z})\) closest to the true posterior in KL divergence.
1. The Intractability Problem
In a latent variable model with observations \(\mathbf{x}\) and latent variables \(\mathbf{z}\), we want the posterior:
\[ P(\mathbf{z}|\mathbf{x}) = \frac{P(\mathbf{x}|\mathbf{z})\,P(\mathbf{z})}{P(\mathbf{x})}, \qquad P(\mathbf{x}) = \int P(\mathbf{x}|\mathbf{z})\,P(\mathbf{z})\,d\mathbf{z}. \]
The marginal \(P(\mathbf{x})\) requires integrating over all configurations of \(\mathbf{z}\) — a high-dimensional integral with no closed form in general, and a sum whose number of terms grows exponentially in the number of discrete latents. Variational inference turns this into an optimisation problem.
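To make the blow-up concrete, the sketch below uses an assumed toy model (iid Bernoulli latents with a Gaussian likelihood — an illustration, not a model from the chapter) and computes \(\log P(\mathbf{x})\) by brute-force enumeration; the number of terms doubles with every extra latent:

```python
import itertools

import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)

def log_evidence_brute_force(x, w):
    """log p(x) = logsumexp_z [log p(z) + log p(x|z)] over all 2^D binary z."""
    D = len(w)
    log_terms = []
    for z in itertools.product([0, 1], repeat=D):
        z = np.array(z)
        log_prior = D * np.log(0.5)                   # z_j ~ Bernoulli(1/2) iid
        log_lik = -0.5 * np.log(2 * np.pi) - 0.5 * (x - w @ z) ** 2
        log_terms.append(log_prior + log_lik)
    return logsumexp(log_terms), len(log_terms)

for D in (4, 8, 12, 16):
    w = rng.normal(size=D)
    _, n_terms = log_evidence_brute_force(1.0, w)
    print(f"D={D:2d}: {n_terms:6d} terms in the exact sum")
```

Already at a few dozen latents the exact sum is hopeless, which is why we turn to optimisation instead.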
2. The ELBO: Full Derivation
Introduce a variational distribution \(q(\mathbf{z})\) and multiply and divide inside the evidence integral by \(q(\mathbf{z})\):
\[ \log P(\mathbf{x}) = \log \int P(\mathbf{x}, \mathbf{z})\,d\mathbf{z} = \log \int q(\mathbf{z})\,\frac{P(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})}\,d\mathbf{z} = \log \mathbb{E}_{q}\!\left[\frac{P(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})}\right]. \]
Apply Jensen's inequality (\(\log\) is concave, so \(\log \mathbb{E}[X] \geq \mathbb{E}[\log X]\)):
\[ \log P(\mathbf{x}) \;\geq\; \mathbb{E}_{q}\!\left[\log \frac{P(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})}\right] \;\equiv\; \mathrm{ELBO}(q). \]
To see the size of the gap, expand the log-evidence using the KL divergence:
\[ \log P(\mathbf{x}) = \mathrm{ELBO}(q) + \mathrm{KL}\big(q(\mathbf{z}) \,\|\, P(\mathbf{z}|\mathbf{x})\big). \]
Since \(\mathrm{KL} \geq 0\) with equality iff \(q = P(\mathbf{z}|\mathbf{x})\), maximising the ELBO is equivalent to minimising the KL to the true posterior. Rewrite the ELBO:
\[ \mathrm{ELBO}(q) = \mathbb{E}_{q}\big[\log P(\mathbf{x}|\mathbf{z})\big] - \mathrm{KL}\big(q(\mathbf{z}) \,\|\, P(\mathbf{z})\big). \]
This decomposition is the VAE objective from Chapter 12: the first term rewards fitting the data; the second penalises the variational distribution for deviating from the prior.
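The decomposition can be checked numerically. The sketch below uses an assumed toy conjugate model — \(z \sim \mathcal{N}(0,1)\), \(x|z \sim \mathcal{N}(z,1)\), so \(P(x) = \mathcal{N}(0,2)\) and \(P(z|x) = \mathcal{N}(x/2, 1/2)\) — and verifies that the gap between \(\log P(x)\) and the ELBO equals \(\mathrm{KL}(q \,\|\, P(z|x))\) for any Gaussian \(q\):

```python
import numpy as np

# Toy conjugate model (an assumption for illustration):
#   z ~ N(0, 1),  x | z ~ N(z, 1)  =>  p(x) = N(0, 2),  p(z|x) = N(x/2, 1/2).
def elbo(x, m, s2):
    """Closed-form ELBO for q(z) = N(m, s2): E_q[log p(x,z)] + entropy of q."""
    e_log_lik = -0.5 * np.log(2 * np.pi) - 0.5 * ((x - m) ** 2 + s2)
    e_log_prior = -0.5 * np.log(2 * np.pi) - 0.5 * (m ** 2 + s2)
    entropy = 0.5 * np.log(2 * np.pi * np.e * s2)
    return e_log_lik + e_log_prior + entropy

def kl_gaussians(m1, v1, m2, v2):
    """KL( N(m1, v1) || N(m2, v2) )."""
    return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1)

x = 1.7
log_px = -0.5 * np.log(2 * np.pi * 2) - x ** 2 / 4        # log N(x; 0, 2)

for m, s2 in [(0.0, 1.0), (1.0, 0.3), (x / 2, 0.5)]:
    gap = log_px - elbo(x, m, s2)
    kl = kl_gaussians(m, s2, x / 2, 0.5)                  # KL(q || p(z|x))
    print(f"m={m:.2f} s2={s2:.2f}  gap={gap:.6f}  KL={kl:.6f}")
```

The gap and the KL agree to machine precision, and both vanish when \(q\) equals the true posterior \(\mathcal{N}(x/2, 1/2)\).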
3. ELBO Decomposition Diagram

(Figure: the fixed quantity \(\log P(\mathbf{x})\) split into the ELBO below and the gap \(\mathrm{KL}(q(\mathbf{z}) \,\|\, P(\mathbf{z}|\mathbf{x}))\) above — raising the ELBO shrinks the gap.)
4. Mean-Field Approximation & Coordinate Ascent VI
The mean-field family assumes the latent variables factorise:
\[ q(\mathbf{z}) = \prod_{j=1}^{M} q_j(z_j). \]
To find the optimal \(q_j^*(z_j)\) holding all others fixed, take the functional derivative of the ELBO with respect to \(q_j\) and set it to zero. Because \(\log q\) decomposes into a sum over factors, the stationarity condition gives:
\[ \log q_j^{*}(z_j) = \mathbb{E}_{i \neq j}\big[\log P(\mathbf{x}, \mathbf{z})\big] + \text{const}. \]
This is the expectation of the log-joint over all variables except \(z_j\). The normalising constant is determined by requiring \(\int q_j^*(z_j)\,dz_j = 1\). Coordinate Ascent VI (CAVI) cycles through all \(j\), updating each \(q_j\) while holding the others fixed — each update is guaranteed never to decrease the ELBO.
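A minimal CAVI sketch on a classic toy target (an assumed example: a zero-mean bivariate Gaussian with correlation \(\rho\)). The coordinate updates recover the true means, but each factor's variance is \(1/\Lambda_{jj} = 1 - \rho^2\), narrower than the true marginal variance of 1 — mean-field's well-known underestimation of posterior spread:

```python
import numpy as np

# Assumed toy target: p(z) = N(0, Sigma) with correlation rho.
rho = 0.9
Sigma = np.array([[1.0, rho], [rho, 1.0]])
Lam = np.linalg.inv(Sigma)                       # precision matrix

m = np.array([2.0, -2.0])                        # arbitrary initialisation
for _ in range(50):                              # coordinate ascent sweeps
    m[0] = -(Lam[0, 1] / Lam[0, 0]) * m[1]       # update q_1 holding q_2 fixed
    m[1] = -(Lam[1, 0] / Lam[1, 1]) * m[0]       # update q_2 holding q_1 fixed

print("CAVI means:", m)                          # both converge to 0, the true means
print("CAVI variances:", 1 / np.diag(Lam))       # 1 - rho^2 = 0.19: too narrow
print("true marginal variances:", np.diag(Sigma))
```

Each sweep is one full CAVI cycle; convergence is geometric with rate \(\rho^2\).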
4.1 Application: Bayesian Linear Regression
With prior \(\mathbf{w} \sim \mathcal{N}(\mathbf{0}, \alpha^{-1}I)\) and likelihood \(\mathbf{y} \sim \mathcal{N}(\mathbf{X}\mathbf{w}, \beta^{-1}I)\), the log-joint is:
\[ \log P(\mathbf{y}, \mathbf{w}) = -\frac{\beta}{2}\,\|\mathbf{y} - \mathbf{X}\mathbf{w}\|^{2} - \frac{\alpha}{2}\,\mathbf{w}^{\top}\mathbf{w} + \text{const}. \]
This is quadratic in \(\mathbf{w}\), so the optimal \(q^*(\mathbf{w}) = \mathcal{N}(\mathbf{m}_N, \mathbf{S}_N)\) with:
\[ \mathbf{S}_N = \big(\alpha I + \beta\,\mathbf{X}^{\top}\mathbf{X}\big)^{-1}, \qquad \mathbf{m}_N = \beta\,\mathbf{S}_N \mathbf{X}^{\top}\mathbf{y}. \]
This matches the exact posterior — a special property of Gaussian models when \(\mathbf{w}\) is kept as a single full-covariance block; a fully factorised mean-field over the components of \(\mathbf{w}\) would recover the mean but underestimate the variances. For non-Gaussian models, VI provides an approximation.
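As a sanity check on the update equations, the sketch below (synthetic data; the values of \(\alpha\), \(\beta\), and the true weights are assumptions for illustration) verifies that \(\mathbf{m}_N\) is a stationary point of the log-joint:

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic data (assumed example): y = X w_true + Gaussian noise.
N, D, alpha, beta = 50, 3, 1.0, 25.0
X = rng.normal(size=(N, D))
w_true = np.array([0.5, -1.0, 2.0])
y = X @ w_true + rng.normal(scale=1 / np.sqrt(beta), size=N)

# Closed-form variational (= exact) posterior from the equations above.
S_N = np.linalg.inv(alpha * np.eye(D) + beta * X.T @ X)
m_N = beta * S_N @ X.T @ y

# m_N should make the gradient of the log-joint vanish:
#   d/dw [ -beta/2 ||y - Xw||^2 - alpha/2 w'w ] = beta X'(y - Xw) - alpha w.
grad = beta * X.T @ (y - X @ m_N) - alpha * m_N
print("posterior mean:", m_N)
print("gradient at m_N (should be ~0):", grad)
```

The vanishing gradient confirms \(\mathbf{m}_N\) is the mode (and, for a Gaussian, the mean) of the posterior, and \(\mathbf{S}_N^{-1}\) is the negative Hessian of the log-joint.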
5. Amortised Inference & VAEs
Classical VI optimises \(q(\mathbf{z})\) separately for each data point — \(O(N)\) separate optimisations. Amortised inference trains a neural network \(q_\phi(\mathbf{z}|\mathbf{x})\) (the encoder) to predict variational parameters directly from \(\mathbf{x}\), maximising the summed ELBO:
\[ \mathcal{L}(\theta, \phi) = \sum_{i=1}^{N} \Big( \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x}_i)}\big[\log p_\theta(\mathbf{x}_i|\mathbf{z})\big] - \mathrm{KL}\big(q_\phi(\mathbf{z}|\mathbf{x}_i) \,\|\, p(\mathbf{z})\big) \Big). \]
This is exactly the VAE objective from Chapter 12. The encoder \(q_\phi\) and decoder \(p_\theta\) are trained jointly by maximising the ELBO via the reparameterisation trick: sample \(\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\varepsilon}\) with \(\boldsymbol{\varepsilon} \sim \mathcal{N}(\mathbf{0}, I)\), which makes the sample a deterministic, differentiable function of \((\boldsymbol{\mu}, \boldsymbol{\sigma})\) so gradients can flow through the sampling step.
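A minimal numpy sketch of the trick (no autodiff; the objective \(\mathbb{E}_{z \sim \mathcal{N}(\mu, \sigma^2)}[z^2]\) is an assumed toy example whose exact gradient with respect to \(\mu\) is \(2\mu\)):

```python
import numpy as np

rng = np.random.default_rng(0)

# Reparameterised Monte Carlo gradient of f(mu) = E_{z ~ N(mu, sigma^2)}[z^2].
mu, sigma, S = 1.5, 0.7, 200_000

eps = rng.standard_normal(S)        # eps ~ N(0, 1): the only source of randomness
z = mu + sigma * eps                # z is a deterministic function of (mu, eps)
dz_dmu = np.ones_like(z)            # dz/dmu = 1 for every sample
grad_est = np.mean(2 * z * dz_dmu)  # chain rule: d(z^2)/dmu = 2 z * dz/dmu

print(f"reparameterised estimate: {grad_est:.3f}  (exact: {2 * mu:.3f})")
```

Because the randomness is pushed into \(\boldsymbol{\varepsilon}\), the same chain-rule computation works under automatic differentiation in a full VAE.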
5.1 Expectation Propagation (brief overview)
EP minimises \(\mathrm{KL}(P(\mathbf{z}|\mathbf{x}) \| q(\mathbf{z}))\) — note the reversed KL compared to VI. This makes EP inclusive: it spreads \(q\) to cover all regions where the posterior has mass, while VI tends to be mode-seeking, collapsing onto a single mode. For multi-modal posteriors the two behave very differently: VI commits to one mode, while EP stretches \(q\) across all of them — useful when coverage matters, though a unimodal \(q\) spread over several modes can place substantial mass in low-probability regions between them.
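The contrast can be made concrete by fitting a single Gaussian to an assumed bimodal mixture and evaluating both KLs by numerical integration on a grid:

```python
import numpy as np

# Assumed toy target: an equal mixture of N(-3, 1) and N(3, 1).
z = np.linspace(-12, 12, 4001)
dz = z[1] - z[0]

def normal_pdf(z, m, s):
    return np.exp(-0.5 * ((z - m) / s) ** 2) / (s * np.sqrt(2 * np.pi))

p = 0.5 * normal_pdf(z, -3, 1) + 0.5 * normal_pdf(z, 3, 1)

q_cover = normal_pdf(z, 0.0, np.sqrt(10.0))   # moment-matched: covers both modes
q_mode = normal_pdf(z, 3.0, 1.0)              # locked onto a single mode

def kl(a, b):
    """Grid approximation of KL(a || b)."""
    mask = a > 1e-300
    return np.sum(a[mask] * np.log(a[mask] / b[mask])) * dz

print("forward KL(p||q):  cover =", kl(p, q_cover), "  mode =", kl(p, q_mode))
print("reverse KL(q||p):  cover =", kl(q_cover, p), "  mode =", kl(q_mode, p))
```

The covering Gaussian wins under the forward (EP-style) KL, while the single-mode Gaussian wins under the reverse (VI-style) KL — exactly the inclusive vs mode-seeking behaviour described above.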
6. Python Simulation: VI vs Exact Posterior
We implement CAVI for Bayesian linear regression with a quadratic feature basis and compare the VI posterior (mean-field Gaussian) against the exact Gaussian posterior. Because the exact posterior is Gaussian, CAVI recovers the exact posterior mean; the fully factorised variances \(1/\Lambda_{jj}\), however, underestimate the exact marginal variances whenever the features are correlated.
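A minimal sketch of the simulation, with assumed data-generating values: fully factorised CAVI (coordinate updates read off the quadratic Gaussian log-joint) against the closed-form posterior:

```python
import numpy as np

rng = np.random.default_rng(42)

# Quadratic feature basis phi(x) = [1, x, x^2]; data values are assumptions.
n = 80
x = rng.uniform(-2, 2, size=n)
Phi = np.stack([np.ones(n), x, x ** 2], axis=1)
w_true = np.array([1.0, -0.5, 0.8])
alpha, beta = 1.0, 10.0
y = Phi @ w_true + rng.normal(scale=1 / np.sqrt(beta), size=n)

# Exact posterior.
Lam = alpha * np.eye(3) + beta * Phi.T @ Phi      # posterior precision
S_N = np.linalg.inv(Lam)
m_N = beta * S_N @ Phi.T @ y

# Fully factorised CAVI: q(w) = prod_j N(m_j, 1/Lam_jj).
b = beta * Phi.T @ y
m = np.zeros(3)
for _ in range(200):                              # coordinate sweeps
    for j in range(3):
        # m_j <- (b_j - sum_{k != j} Lam_jk m_k) / Lam_jj
        m[j] = (b[j] - Lam[j] @ m + Lam[j, j] * m[j]) / Lam[j, j]

print("exact mean :", m_N)
print("CAVI mean  :", m)                          # identical at convergence
print("exact vars :", np.diag(S_N))
print("CAVI vars  :", 1 / np.diag(Lam))           # mean-field underestimates
```

The means coincide (the coordinate updates are Gauss–Seidel sweeps on the normal equations), while the factor variances \(1/\Lambda_{jj}\) are never larger than the exact marginals \((\mathbf{S}_N)_{jj}\).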