Bayesian inference, statistical physics, and the planted directed polymer problem
To cite this page
@misc{swpkim2024bayesian,
author={P. Kim, Sun Woo},
title={Bayesian inference, statistical physics, and the planted directed polymer problem},
year={2024},
howpublished={\url{https://sunwoo-kim.github.io/ko/activities/talks/240409-imperial/}},
note={Accessed: 2024-11-18}
}
Introduction
Recently, there has been a lot of activity in casting Bayesian inference problems as disordered stat-mech problems, and then using tools from spin glasses to solve them; see Zdeborová 1511.02476. Here I wanted to give a quick introduction to this idea.
In Bayesian inference, we try to infer a state $x$ given some measurements/data $y$. We assume some prior distribution $p(x)$ and a measurement model/likelihood $p(y \vert x)$, and infer a posterior distribution using Bayes' rule, $$ p(x \vert y) = \frac{p(y \vert x) p(x)}{p(y)}, $$ where the 'evidence' $p(y) = \sum_x p(y \vert x) p(x)$ acts as the normalisation for the posterior. If the state space is large, computing this sum is in general intractable. The evidence is very similar to the partition function in statistical physics.
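As a toy numerical illustration of Bayes' rule on a small discrete state space (the numbers below are made up):

```python
import numpy as np

# Toy discrete example: three possible states x, one observed y.
prior = np.array([0.5, 0.3, 0.2])          # p(x)
likelihood = np.array([0.1, 0.7, 0.4])     # p(y | x) evaluated at the observed y
evidence = np.sum(likelihood * prior)      # p(y) = sum_x p(y|x) p(x), the 'partition function'
posterior = likelihood * prior / evidence  # p(x | y) by Bayes' rule

print(posterior, posterior.sum())          # -> approx [0.147, 0.618, 0.235], sums to 1
```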
If $y$ came from the 'true' state $x^*$, the natural question to ask is: is inference possible? To quantify this, for continuous variables we may look at the mean-squared error, $$ \mathrm{MSE} = \mathop{\mathbb{E}}_{x \sim p(\cdot \vert y)} [(x-x^*)^2]. $$ Is there any way we can prove when inference is possible? And can we have phase transitions, say in the mean-squared error, as the signal-to-noise ratio is varied?
We will consider an idealised scenario covered in the above reference, called the teacher-student scenario. The teacher generates some state of a system, $x^*$, from the 'true/teacher's prior' $p_\mathrm{T}(x^*)$. Then, the teacher generates some data/measurements $y$ given the state of the system via the teacher's measurement model/likelihood, $p_\mathrm{T}(y \vert x^*)$. The teacher then hands the student the measurements $y$.
The student then assumes some prior distribution $p_\mathrm{S}(x)$ and measurement model $p_\mathrm{S}(y\vert x)$ to infer a posterior distribution $p_\mathrm{S}(x \vert y)$. Then we can consider the joint distribution, $$ \begin{aligned} p(x,y,x^*) & = p_\mathrm{S}(x \vert y) p_\mathrm{T}(y \vert x^*) p_\mathrm{T}(x^*) \\ & = \frac{p_\mathrm{S}(y \vert x)p_\mathrm{S}(x) p_\mathrm{T}(y \vert x^*) p_\mathrm{T}(x^*)}{p_\mathrm{S}(y)}. \end{aligned} $$ Note that $p(x^*)=p_\mathrm{T}(x^*)$, since we can integrate $p_\mathrm{S}(x\vert y)$ over $x$ and then $p_\mathrm{T}(y \vert x^*)$ over $y$. However, $p(x) \neq p_\mathrm{S}(x)$ in general, since the integrals over $y$ and $x^*$ cannot be performed first in the same way.
However, we can already notice one thing immediately: if the student's model is equal to the teacher's model, that is $\mathrm{S}=\mathrm{T}$, then $p(x,y,x^*)$ is symmetric with respect to $x \leftrightarrow x^*$, and therefore $x$ is distributed identically to $x^*$, with $p(x) = p_\mathrm{T}(x^*=x)$. The $\mathrm{S}=\mathrm{T}$ point is called the Bayes-optimal point.
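To see this symmetry concretely, here is a sketch with a scalar Gaussian toy model (my own example, not from the reference): the teacher draws $x^* \sim \mathcal{N}(0,1)$ and adds Gaussian noise to produce $y$, and the student, whose only possible mismatch is its assumed noise level, samples from its Gaussian posterior. At the Bayes-optimal point the student's samples $x$ have the same statistics as $x^*$; with a mismatched noise level they do not.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Teacher: x* ~ N(0, 1), y = x* + noise of standard deviation sigma_T.
sigma_T = 0.5
x_star = rng.standard_normal(n)
y = x_star + sigma_T * rng.standard_normal(n)

# Student: same model form, but with its own assumed noise level sigma_S.
# For this model the posterior is Gaussian: p_S(x | y) = N(y / (1 + sigma_S^2), sigma_S^2 / (1 + sigma_S^2)).
def sample_student_posterior(y, sigma_S):
    mean = y / (1 + sigma_S**2)
    var = sigma_S**2 / (1 + sigma_S**2)
    return mean + np.sqrt(var) * rng.standard_normal(y.shape)

x_bayes_opt = sample_student_posterior(y, sigma_T)  # S = T: Bayes-optimal point
x_mismatch = sample_student_posterior(y, 2.0)       # mismatched student

# At the Bayes-optimal point, x is distributed like x* (both variances are ~1);
# with a mismatched model it is not.
print(np.var(x_star), np.var(x_bayes_opt), np.var(x_mismatch))
```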
To make the connection to stat-mech clearer, let's write out the student's posterior in a suggestive way: $$ p_\mathrm{S}(x \vert y) = \frac{e^{-\beta H (x, y)}}{Z(y)}, $$ where $Z(y) = \sum_x e^{-\beta H(x, y)}$ and $$ -\beta H (x, y) = \ln q_\mathrm{S}(y \vert x) + \ln q_\mathrm{S}(x). $$ Here I wrote unnormalised distributions, e.g. $q_\mathrm{S}(y \vert x) \propto p_\mathrm{S}(y \vert x)$, as normalisation is taken care of by the partition function. We see that the posterior looks like the Boltzmann probability of a disordered system (e.g. a spin glass), with the 'data' $y$ playing the role of the disorder, drawn from $p_\mathrm{T}(y) = \sum_{x^*} p_\mathrm{T}(y \vert x^*) p_\mathrm{T}(x^*)$.
In the limit where the data contains no information about the true/teacher's state, e.g. when the data is 'infinitely' noisy, $p_\mathrm{T}(y \vert x^*) \rightarrow p_\mathrm{T}(y)$ and the data really becomes random disorder - this would be a situation where the student believes that there is some information in the data, but the teacher is actually supplying pure noise.
When there is information about the true/teacher's state in the data, this information is said to be 'planted' in the disorder. The marginal $p_\mathrm{T}(y) = \sum_{x^*} p_\mathrm{T}(y \vert x^*) p_\mathrm{T}(x^*)$ is the 'planted distribution', and $p(x, y, x^*)$ is the 'planted ensemble'.
This is all well and good, but is it actually useful? Let’s look at some examples, to see how existing disordered stat-mech problems map onto inference problems.
The planted random-bond Ising model
To make things more concrete, let’s consider a particular case of a model considered in the review. We consider $L^2$ people standing in a room in a square lattice formation. To each person, we randomly hand out a card that can be $S^*_i = \pm 1$ with equal probability. However, we don’t record this information. Then, we ask each neighbouring pair to reply with $1$ if they have the same cards or $-1$ if they have different cards. But there’s a twist - for each question the pair may lie with probability $\uppi$. We record the answer that each pair gives us as $J_{ij}$.
Then, given $\{J_{ij}\}_{ij}$, can we infer the 'state', i.e. the cards given to each person? As we handed out the cards with equal probability, the prior is $p(S) \propto 1$. The 'likelihood' of answer $J_{ij}$ given cards $S_i$ and $S_j$ is $$ p(J_{ij} \vert S_i, S_j) = \begin{cases} 1 - \uppi & J_{ij} = S_i S_j, \\ \uppi & J_{ij} = -S_i S_j. \end{cases} $$ As $J_{ij}$ is an Ising variable, we can massage this to look like a Boltzmann weight. If we write $$ p(J_{ij} \vert S_i, S_j) = N e^{\beta J_{ij} S_i S_j}, $$ then normalisation over $J_{ij} \in \{1, -1\}$ gives $N = 1/(e^{\beta} + e^{-\beta})$, and $$ \uppi = \frac{e^{-\beta}}{e^{\beta} + e^{-\beta}}, \quad \text{i.e.} \quad \beta = \frac{1}{2}\ln\frac{1-\uppi}{\uppi}. $$ Therefore if each pair lies half of the time, $\uppi = 1/2$ and $\beta = 0$, and if they never lie, $\uppi = 0$ and $\beta \rightarrow \infty$.
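For concreteness, the relation between the lying probability and the effective inverse temperature can be checked in a couple of lines (the example values are arbitrary):

```python
import numpy as np

def beta_from_pi(pi):
    """Inverse temperature beta such that pi = e^{-beta} / (e^{beta} + e^{-beta})."""
    return 0.5 * np.log((1 - pi) / pi)

def pi_from_beta(beta):
    """Lying probability pi corresponding to inverse temperature beta."""
    return np.exp(-beta) / (2 * np.cosh(beta))

print(beta_from_pi(0.5))                # 0.0: answers carry no information
print(beta_from_pi(0.1))                # ~1.10: fairly reliable answers
print(pi_from_beta(beta_from_pi(0.1)))  # ~0.1: round trip
```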
Then the likelihood over all answers is $$ p(J \vert S) = \prod_{\langle i,j \rangle} p(J_{ij} \vert S_i, S_j) \propto \exp\left(\beta \sum_{\langle i, j\rangle} J_{ij} S_i S_j \right). $$ Since the prior is uniform, the posterior is $$ p(S \vert J) = \frac{q(J \vert S)}{Z(J)} = \frac{\exp\left(\beta \sum_{\langle i, j \rangle}J_{ij} S_i S_j\right)}{Z(J)}, $$ where $$ Z(J) = \sum_{S} \exp\left(\beta \sum_{\langle i, j\rangle} J_{ij} S_i S_j \right) $$ is the partition function for a random-bond Ising model (RBIM), and the posterior is its Boltzmann probability!
Let's now consider the general teacher-student scenario, where the student's assumed rate of lying $\beta_\mathrm{S}(\uppi_\mathrm{S})$ can differ from the true/teacher's $\beta_\mathrm{T}(\uppi_\mathrm{T})$: $$ \begin{aligned} p(S, J, S^*) & = p_\mathrm{S}(S \vert J) p_\mathrm{T}(J \vert S^*) p_\mathrm{T}(S^*) \\ & \propto \frac{\exp\left( \beta_\mathrm{S} \sum_{\langle ij \rangle} J_{ij} S_i S_j\right)}{Z_\mathrm{S}(J)} \exp\left( \beta_\mathrm{T} \sum_{\langle ij \rangle} J_{ij} S^*_i S^*_j\right). \end{aligned} $$ Now consider the observable $S_l S^*_l$. It is $1$ if the inferred card agrees with the true card at site $l$ and $-1$ if it does not, so its average is a measure of whether inference is possible: $$ \mathop{\mathbb{E}}[S_l S^*_l] \propto \sum_{S^*} \sum_{S} \sum_{J} S_l S^*_l \frac{\exp(\beta_\mathrm{S} \sum_{\langle ij\rangle}J_{ij}S_i S_j) \exp(\beta_\mathrm{T} \sum_{\langle ij\rangle}J_{ij}S^*_i S^*_j)}{Z_\mathrm{S}(J)}. $$ Since we are summing over $J$, we are free to redefine $J$ via $\tilde{J}_{ij} = J_{ij} S^*_i S^*_j$. Similarly, we can redefine $S$ via $\tilde{S}_i = S_i S^*_i$, and the dummy configuration $S'$ summed over inside the partition function can be redefined in the same way. Therefore $$ \mathop{\mathbb{E}}[S_l S^*_l] \propto \sum_{S^*} \sum_{\tilde{S}} \sum_{\tilde{J}} \tilde{S}_l \frac{\exp(\beta_\mathrm{S} \sum_{\langle ij\rangle} \tilde{J}_{ij} \tilde{S}_i \tilde{S}_j) \exp(\beta_\mathrm{T} \sum_{\langle ij\rangle}\tilde{J}_{ij})}{\sum_{\tilde{S}'} \exp\left( \beta_\mathrm{S} \sum_{\langle ij\rangle} \tilde{J}_{ij} \tilde{S}'_i \tilde{S}'_j\right)}. $$ The summand no longer depends on $S^*$, and the true configuration has been mapped onto the 'ferromagnetic configuration' $\tilde{S}^*_i = 1$. So $$ \begin{aligned} \mathop{\mathbb{E}}[S_l S^*_l] & \propto \sum_{\tilde{S}} \sum_{\tilde{J}} \tilde{S}_l \frac{\exp(\beta_\mathrm{S} \sum_{\langle ij\rangle} \tilde{J}_{ij} \tilde{S}_i \tilde{S}_j) \exp(\beta_\mathrm{T} \sum_{\langle ij\rangle}\tilde{J}_{ij})}{\sum_{\tilde{S}'} \exp\left( \beta_\mathrm{S} \sum_{\langle ij\rangle} \tilde{J}_{ij} \tilde{S}'_i \tilde{S}'_j\right)} \\ & = \sum_{\tilde{J}} \frac{\sum_{\tilde{S}} \tilde{S}_l \exp(\beta_\mathrm{S} \sum_{\langle ij\rangle} \tilde{J}_{ij} \tilde{S}_i \tilde{S}_j) \exp(\beta_\mathrm{T} \sum_{\langle ij\rangle}\tilde{J}_{ij})}{\sum_{\tilde{S}'} \exp\left( \beta_\mathrm{S} \sum_{\langle ij\rangle} \tilde{J}_{ij} \tilde{S}'_i \tilde{S}'_j\right)}, \end{aligned} $$ which is the magnetisation of a random-bond Ising model with bond distribution $p(\tilde{J}) \propto \exp(\beta_\mathrm{T} \sum_{\langle ij\rangle}\tilde{J}_{ij})$. So the paramagnetic (PM) phase corresponds to the failure of inference, and the ferromagnetic (FM) phase corresponds to the success of inference.
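Here is a minimal numerical sketch of this correspondence (not from the talk): generate planted bonds $J_{ij} = S^*_i S^*_j$ flipped with probability $\uppi_\mathrm{T}$, sample the student's posterior with a simple Metropolis algorithm, and measure the overlap with the planted cards. The lattice size, temperatures, and sweep count are toy choices of mine.

```python
import numpy as np

rng = np.random.default_rng(0)

L = 16                       # linear lattice size (toy value)
beta_T = beta_S = 1.2        # teacher/student inverse temperatures (equal: Bayes optimal)
pi_T = np.exp(-beta_T) / (2 * np.cosh(beta_T))   # teacher's lying probability

# Teacher: hand out random cards S* and generate noisy bond answers J.
S_star = rng.choice([-1, 1], size=(L, L))
J_right = S_star * np.roll(S_star, -1, axis=1)   # S*_i S*_j on horizontal bonds
J_down = S_star * np.roll(S_star, -1, axis=0)    # S*_i S*_j on vertical bonds
J_right = np.where(rng.random((L, L)) < pi_T, -J_right, J_right)  # lies flip the answer
J_down = np.where(rng.random((L, L)) < pi_T, -J_down, J_down)

# Student: Metropolis sampling of p_S(S | J) ∝ exp(beta_S * sum_<ij> J_ij S_i S_j).
def local_field(S, i, j):
    """Sum of J_ij S_j over the four neighbours of site (i, j), periodic boundaries."""
    return (J_right[i, j] * S[i, (j + 1) % L]
            + J_right[i, (j - 1) % L] * S[i, (j - 1) % L]
            + J_down[i, j] * S[(i + 1) % L, j]
            + J_down[(i - 1) % L, j] * S[(i - 1) % L, j])

S = rng.choice([-1, 1], size=(L, L))
for sweep in range(1000):
    for _ in range(L * L):
        i, j = rng.integers(L), rng.integers(L)
        dE = 2 * S[i, j] * local_field(S, i, j)   # energy change if S[i, j] is flipped
        if dE <= 0 or rng.random() < np.exp(-beta_S * dE):
            S[i, j] *= -1

# Overlap with the planted cards: large in the FM phase (inference succeeds),
# near zero in the PM phase (inference fails). The global sign is arbitrary.
print("overlap:", abs(np.mean(S * S_star)))
```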
By the way, in the spin-glass literature, what we've done here is called a 'gauge transformation'. There is also a special 'solvable line' in this model, called the Nishimori line, where $p(J) \propto Z(J)$. The basic idea is that $p(J)$ just becomes another replica in the problem and so simplifies the calculation, but it turns out that the point where the Nishimori line crosses the phase boundary is very interesting. This line exactly corresponds to the Bayes-optimality condition, $\beta_\mathrm{S} = \beta_\mathrm{T}$. Usually, people show that for particular choices of $\uppi$, one can gauge-transform $p(J)$ to look like $Z(J)$; here, we did it the other way around.
People have already studied the phase diagram of the RBIM, so we get it for free. Let's look at it (see, for example, Gruzberg 0007254v2).
The point $p=0$, $T_\mathrm{c} \approx 2.27$ corresponds to the clean 2D Ising transition. For us, $T$ corresponds to $\beta_\mathrm{S}^{-1}$ and $p = \uppi_\mathrm{T}$. Thinking of $\beta_\mathrm{S}$ and $\beta_\mathrm{T}$ as 'signal strengths', we can redraw the phase diagram in the $\beta_\mathrm{T}$-$\beta_\mathrm{S}$ plane. We know that the Nishimori line is the $45^\circ$ line in this plane.
We see that in this picture, no matter what $\beta_\mathrm{S}$ the student chooses, the teacher's signal strength must be sufficiently large, $\beta_\mathrm{T} > \beta^\mathrm{c}_\mathrm{T}$, for inference to be possible. On the other hand, if the student assumes that the signal strength is too low, $\beta_\mathrm{S} < \beta^\mathrm{c}_\mathrm{S}$, then again inference is not possible.
The planted directed polymer problem
So now the world is our oyster. We can start cooking up planted versions of spin-glass/disordered stat-mech models, and see what inference problems they map onto. We were particularly interested in a class of inference problems involving hidden Markov models. This is where the prior follows a Markov process, $$ p(x_{1:t})=\left(\prod_{\tau=2}^tp(x_\tau|x_{\tau-1})\right)p(x_1), $$ and measurements $y_{t}$ conditioned on the state are generated through some measurement model $p(y_{t} \vert x_{t})$ at every timestep. As before, we could calculate a posterior for the entire trajectory, $p(x_{1:t} \vert y_{1:t})$, but this is usually computationally intractable. Instead, we can consider a filtering task, where the state at the current timestep is inferred from the data at all timesteps up to and including the current one, i.e. $p(x_{t} \vert y_{1:t})$.
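For reference, the filtering posterior can be built up one timestep at a time via the standard forward recursion for hidden Markov models, $$ p(x_{t} \vert y_{1:t}) \propto p(y_{t} \vert x_{t}) \sum_{x_{t-1}} p(x_{t} \vert x_{t-1})\, p(x_{t-1} \vert y_{1:t-1}), $$ which costs only one sum over states per timestep rather than a sum over whole trajectories.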
Concretely, we consider the case when the states are locations in space, and the Markov process is a random walk on that space. A simple example of a 'kernel' for this process is $$ p(x_{t+1} \vert x_{t}) = \frac{1}{2} \delta_{x_{t+1},x_{t}} + \frac{1}{4} \left( \delta_{x_{t+1},x_{t}+1} + \delta_{x_{t+1},x_{t}-1} \right). $$ An example of such a random walk in 1D is shown below.
At every timestep $t$, we will take an 'image' of the walker specified by pixel values at positions $x$, $$ \phi_{x,t} = \epsilon_\mathrm{T} \delta_{x, x^*_{t}} + \psi_{x,t}, $$ which has a peak at the true location of the walker with 'signal strength' $\epsilon_\mathrm{T}$, where $\psi_{x,t}$ is iid Gaussian noise, $\psi_{x,t} \sim \mathcal{N}(0, \sigma_\mathrm{T}^2)$. An example of the set of images, one per timestep, for the above random walk is shown below.
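In code, the teacher's generative process might look like the following sketch (a ring of $L$ sites with periodic boundaries and made-up parameter values; these choices are mine, not from the talk):

```python
import numpy as np

rng = np.random.default_rng(1)
L, T = 64, 100                 # number of sites and timesteps (toy values)
eps_T, sigma_T = 1.0, 1.0      # teacher's signal and noise strengths

# Lazy random walk with kernel p(x'|x) = 1/2 if x'=x, 1/4 if x'=x±1 (periodic boundaries).
x_star = np.empty(T, dtype=int)
x_star[0] = L // 2
steps = rng.choice([-1, 0, 0, 1], size=T - 1)    # ±1 with prob 1/4 each, 0 with prob 1/2
x_star[1:] = (x_star[0] + np.cumsum(steps)) % L

# Noisy images: phi[x, t] = eps_T * delta(x, x*_t) + iid Gaussian noise of std sigma_T.
phi = sigma_T * rng.standard_normal((L, T))
phi[x_star, np.arange(T)] += eps_T
```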
Then, we'd like to infer the true position of the walker at the current time, $x^*_{t}$, or the true trajectory $X^* := (x^*_\tau)_{\tau=1}^T$, given the images. We call this the 'planted directed polymer problem', as it is related to the directed polymer in a random medium from stat-mech. Let's see how.
Let the whole image at time $t$ be $\boldsymbol{\phi}_{t} := (\phi_{x,t})_x$. Then, assuming signal strength $\epsilon_\mathrm{S}$ and noise strength $\sigma_\mathrm{S}$, the likelihood for observing an image $\boldsymbol{\phi}_{t}$ given that the walker's position is $x_{t}$ is $$ \begin{aligned} p_\mathrm{S}(\boldsymbol{\phi}_{t} \vert x_{t}) & = \prod_{x'} \frac{1}{\sqrt{2 \pi \sigma_\mathrm{S}^2}} \exp \left[ \frac{-(\phi_{x', t} - \epsilon_\mathrm{S} \delta_{x_{t}, x'})^2}{2 \sigma_\mathrm{S}^2} \right] \\ & = \exp\left(\left[\epsilon_\mathrm{S} \phi_{x_{t},t} - \epsilon_\mathrm{S}^2/2\right]/\sigma_\mathrm{S}^2\right)\pi_\mathrm{S}(\boldsymbol{\phi}_{t}), \end{aligned} $$ where $\pi_\mathrm{S}(\boldsymbol{\phi}_t)$ denotes the Gaussian measure. Denoting $X := x_{1:T}$ and $\Phi := \boldsymbol{\phi}_{1:T}$, the likelihood for the whole trajectory is $$ p_\mathrm{S}(\Phi \vert X) \propto \exp\left(\frac{\epsilon_\mathrm{S}}{\sigma_\mathrm{S}^2}\sum_{t=1}^{T} \phi_{x_{t}, t} \right)\pi_\mathrm{S}(\Phi). $$ As the $\pi_\mathrm{S}(\Phi)$ factor is the same for any value of $X$, we can drop it and simply renormalise the posterior. So the student's posterior for the whole path is $$ p_\mathrm{S}(X \vert \Phi) = \frac{q_\mathrm{S}(X \vert \Phi)}{Z_\mathrm{S}(\Phi)}, $$ where the unnormalised posterior is $$ q_\mathrm{S}(X \vert \Phi) = \exp\left(\frac{\epsilon_\mathrm{S}}{\sigma_\mathrm{S}^2}\sum_{t=1}^T \phi_{x_{t},t} + \sum_{t=2}^{T} \ln q_\mathrm{S}(x_{t} \vert x_{t-1}) \right) q_\mathrm{S}(x_1), $$ and $Z_\mathrm{S}(\Phi) = \sum_X q_\mathrm{S}(X \vert \Phi)$.
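As a sketch, the filtered posterior $p_\mathrm{S}(x_{t} \vert \boldsymbol{\phi}_{1:t})$ for this model can be computed exactly with the forward recursion above, which is the same transfer-matrix contraction that builds up $Z_\mathrm{S}(\Phi)$ column by column. The parameters, the ring geometry, and the success metric printed at the end are my own toy choices.

```python
import numpy as np

rng = np.random.default_rng(1)
L, T = 64, 100                        # lattice sites and timesteps (toy values)
eps_T = eps_S = 1.0                   # Bayes optimal: student assumes the teacher's parameters
sigma_T = sigma_S = 1.0

# Teacher: planted walk and noisy images, as in the previous sketch.
x_star = np.empty(T, dtype=int)
x_star[0] = L // 2
x_star[1:] = (x_star[0] + np.cumsum(rng.choice([-1, 0, 0, 1], size=T - 1))) % L
phi = sigma_T * rng.standard_normal((L, T))
phi[x_star, np.arange(T)] += eps_T

# Student's walk kernel on the ring, K[x_t, x_{t-1}] = q_S(x_t | x_{t-1}).
K = 0.5 * np.eye(L) + 0.25 * (np.roll(np.eye(L), 1, axis=1) + np.roll(np.eye(L), -1, axis=1))

# Forward recursion for the filtering posterior p_S(x_t | phi_{1:t}).
# The x-independent -eps_S^2/2 term in the likelihood drops out on normalisation.
post = np.full(L, 1.0 / L)            # uniform prior q_S(x_1)
prob_true = []
for t in range(T):
    if t > 0:
        post = K @ post               # propagate one step of the walk
    post = post * np.exp(eps_S * phi[:, t] / sigma_S**2)   # reweight by the image at time t
    post /= post.sum()                # normalise (the product of the dropped factors is related to Z_S(Phi))
    prob_true.append(post[x_star[t]]) # posterior weight on the true position

print("mean posterior weight on the true position:", np.mean(prob_true))
```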
Let us assume that the kernel $q_\mathrm{S}(x_t \vert x_{t-1})$ only depends on the distance between $x_t$ and $x_{t-1}$, and is maximised when the distance is zero. Then, we write $\ln q_\mathrm{S}(x_{t} \vert x_{t-1}) = f_\mathrm{S}(\sqrt{a}(x_{t} - x_{t-1}))$, where $a$ is the lattice constant. Expanding about zero separation (using $f'_\mathrm{S}(0) = 0$ and $f''_\mathrm{S}(0) < 0$), and absorbing constants into the normalisation, we find $$ q_\mathrm{S}(X \vert \Phi) = \exp\Bigg( -\frac{a}{4 \nu_\mathrm{S}} \sum_{\tau=2}^{T} \left(x_\tau - x_{\tau-1}\right)^2 + \frac{\epsilon_\mathrm{S}}{\sigma_\mathrm{S}^2} \sum_{\tau=1}^{T} \phi_{x_\tau, \tau} +\mathcal{O}(a^{3/2})\Bigg) q_\mathrm{S}(x_1), $$ where $\nu_\mathrm{S} = \frac{1}{2 \lvert f''_\mathrm{S}(0) \rvert}$ sets the 'width' of the kernel. We see that the unnormalised posterior now looks like the Boltzmann weight for the directed polymer in a random environment. For those who are unfamiliar, it is a well-studied stat-mech model for configurations of a polymer. Interpreting time as another spatial dimension, $t=y$, it is 'directed' in that the polymer cannot loop around itself, since $x$ is a single-valued function of $y$.
The first term in the exponent, $\sim (x_\tau - x_{\tau-1})^2$, is the 'elastic energy' of the polymer: you pay a price whenever you stretch the polymer out. The second term, $\sim \phi_{x_\tau, \tau}$, is a 'random environment' that the polymer is placed in; in the usual directed polymer setting, it is iid random noise.
The interesting phenomenon in the directed polymer is that even though you pay energy to stretch out the polymer, it can be worth it to pass through a particular location if the potential there is deep enough. This creates a 'roughening' of the polymer (you can read more about it in Bhattacharjee 0402117v1, for example).
For us, $\phi_{x_t, t}$ is not iid, but instead has information about the true location of the walker ‘planted’ in it. In recent work, we studied this problem in detail in 1D and also on the tree. Read more here: P. Kim 2404.07263.