Annealed Importance Sampling

We discuss an algorithm that reaps the benefits of both MCMC and Importance Sampling: Annealed Importance Sampling. This blog builds on my previous blogs on Importance Sampling (IS) and Markov Chain Monte Carlo (MCMC). First, I’ll discuss why the marginal likelihood is relevant via an example. Then I’ll introduce the two main ideas behind Annealed Importance Sampling and end with a discussion on understanding those two main ideas.

Before we start, I’d like to thank Jan-Willem van de Meent for his lectures in his Advanced Machine Learning class for PhD students at Northeastern University. The images shown are either inspired by or taken directly from his lectures. I’d also add that more details on Annealed Importance Sampling can be found in this paper by Radford M. Neal.

As previously discussed, the benefit that MCMC has over Importance Sampling is its transition kernel, which allows us to stay within the target density once we already have a sample from it. However, as I mentioned in my previous blog, MCMC: Metropolis-Hastings, Importance Sampling gives us something MCMC does not: an estimate of the marginal likelihood, p(y). So now we answer the question: what is the benefit of having an estimate of the marginal likelihood? We discuss with an example.

Why is the Marginal Likelihood Relevant?

Suppose that we are doing clustering on the data points shown in the figure above, and we ask the question, “How many clusters K should I use?” In the first example, we fit 2 clusters to the data; in the second, 3 clusters; and in the last, 10 clusters. If we look at the likelihood of the data given the parameters, then the last example has a really high likelihood, p(y|\theta), while the first has a low one. Yet even though the last example has high likelihood, we also know that its clusters are over-fitting. The question is: can we describe mathematically what over-fitting looks like?

The intuition is: if we randomly drew new clusters for the data, how likely is it that the clusters would land on top of the data? For the last example, the probability is next to zero. On the other hand, for the first example, some of the points would probably fall into the new clusters (shown in blue). We can formalize this by looking at the marginal likelihood, p(y|k), the probability of y given the number of clusters k.

K^* = \underset{k \in \{1, \dots, K^{\max}\}}{\text{argmax}} \; p(y \mid k) = \underset{k}{\text{argmax}} \int p(y \mid \theta)\, p(\theta \mid k)\, d\theta

where p(\theta \mid k) describes the probability of the positions of the clusters given the number of clusters. We go from k = 1 cluster up to the maximum number of clusters, and for each value of k we compute the integral. We can think of the marginal likelihood as the Best Average Fit for the data!
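To make this concrete, here is a minimal sketch (my own, not from the lectures) of estimating p(y \mid k) by simple Monte Carlo: draw cluster parameters from the prior and average the likelihood. It assumes a 1-D equal-weight Gaussian mixture with known \sigma and an N(0, 5^2) prior on the means; the function name and all constants are illustrative choices.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

def log_marginal_likelihood(y, k, n_samples=10_000, sigma=1.0, seed=0):
    # Monte Carlo estimate of log p(y | k) = log E_{p(theta | k)}[ p(y | theta) ].
    rng = np.random.default_rng(seed)
    means = rng.normal(0.0, 5.0, size=(n_samples, k))      # theta^s ~ p(theta | k)
    # log p(y_i | theta^s) under an equal-weight mixture of k Gaussians.
    comp = norm.logpdf(y[None, :, None], means[:, None, :], sigma)
    log_lik = logsumexp(comp - np.log(k), axis=2).sum(axis=1)
    # Average p(y | theta^s) over the prior draws, in log space for stability.
    return logsumexp(log_lik) - np.log(n_samples)

rng = np.random.default_rng(0)
y = np.concatenate([rng.normal(-4.0, 1.0, 50), rng.normal(4.0, 1.0, 50)])
for k in [1, 2, 3, 10]:
    print(k, log_marginal_likelihood(y, k))   # k = 2 should win; k = 10 over-fits
```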

Comparing Importance Sampling with MCMC

So now we know why having the marginal likelihood is kind of a big deal. However, Importance Sampling has an annoying property that makes it hard to generate many good samples. Consider the example we just discussed: Importance Sampling would basically guess randomly where to put the clusters, which is not efficient. MCMC, on the other hand, has the nice property that if we have a bad sample, the transition kernel can give us a new one, and if the new sample is better, we keep it. So as we keep getting better and better samples, we are more likely to get more good samples than if we guessed randomly (like in Importance Sampling).

So the question now is, “Why not both MCMC and IS?”

Annealed Importance Sampling

Idea 1 (Importance Sampling)

Imagine that we have two tiny clusters (above on the left), and we decide to sample uniformly from the shaded square within the plot (right). It would take a very long time to get enough samples that lie within the two tiny clusters. So the question we ask here is, “Can we make this problem a little bit easier?” How can we make it easier?

Well, imagine that we take the small clusters and make them a little broader. Then we take those broader clusters and broaden them again, and we keep going until we end up with one large cluster.

We end up with something we can easily sample from! But how can we make this happen… and how does it really solve our initial problem? Well, what if we could get samples from the broadest cluster and then refine those samples to get samples from a less broad cluster… and keep going until our samples are so refined that they come from the original two small clusters?

That seems like a good idea, right!?

Idea 1: Break the problem into simpler problems. Specifically, sample via a sequence of intermediate distributions.

Let’s say that our end goal is to sample from some joint distribution, p(y, \theta), of some data, y, and latent variables, \theta (which can be model parameters). However, we also want a really easy distribution to start with. Well, the easiest probability distribution we can use is the prior, p(\theta). It’s easy because sampling from the prior doesn’t require any inference.

So now the question is, “what can we do so that we start at the easy distribution and smoothly transition to the final (and hard) joint distribution?”

Let’s think about this! The relationship between the prior and the joint probability is that the joint probability is equal to the likelihood times the prior:

p(y, \theta) = p(y \mid \theta)\, p(\theta)

So what really lets us start from the prior and end up at the joint distribution is the likelihood! We basically remove the likelihood from the joint and then gradually add it back in.

One way we can do this is by introducing a number, \beta_n, which we use to define the n^{th} intermediate distribution:

\gamma_n(\theta) = p(y \mid \theta)^{\beta_n} p(\theta)

If \beta_n = 0, what do we have? We are left with just the prior. If \beta_n = 1, we have the full joint distribution. So to get a sequence of intermediate distributions, we pick a sequence of values that smoothly goes from 0 to 1.
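As a quick sketch of how this might look in code (log_likelihood and log_prior are placeholder callables, and the 50-step linear schedule is just one reasonable choice):

```python
import numpy as np

def log_gamma(theta, beta, log_likelihood, log_prior):
    # Unnormalized log density of the n-th intermediate distribution:
    # log gamma_n(theta) = beta_n * log p(y | theta) + log p(theta).
    return beta * log_likelihood(theta) + log_prior(theta)

# A smooth schedule of beta values going from 0 (the prior) to 1 (the joint).
betas = np.linspace(0.0, 1.0, num=50)
```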

Idea 2 (Transition Kernels)

Now that we know how to get intermediate distributions, let’s talk about how we can use them.

Before we move on, let’s think about the overall goal again. We started off with an example showing that we want to learn clusters for some data, y. If we are using Gaussian distributions to represent our clusters, then our parameters, \theta, would be the mean and variance, \theta = \{\mu, \sigma^2\}. Therefore we are trying to learn the cluster means and variances that best represent the data.

Initialization

Let’s begin by drawing a first sample from some proposal distribution: \theta_0^s \sim q(\cdot) (using a zero to signify the first sample we have). The sample is going to have an importance weight, w_0^s = \frac{\gamma_0(\theta_0^s)}{q(\theta_0^s)}, just like in Importance Sampling.

Now imagine that we have two samples (as shown in the figure above in green), lying within the prior distribution, \gamma_0(\cdot). So now what we want to do is evaluate those samples relative to the next (and new) density in the sequence, \gamma_1(\cdot).

To do that, we are going to use some form of transition kernel (just like in MCMC) to move the samples within the new density. We can move a sample several times; this gives us a new sample within the new density. In general, at step n:

\theta_n^s \sim \kappa_n(\theta_n \mid \theta_{n-1}^s)

However, this means that we also have a new importance weight. The new weight, w_n^s, comes from evaluating the sample at the new density before the move (I will derive it in a moment):

w_n^s = \frac{\gamma_n(\theta_{n-1}^s)}{\gamma_{n-1}(\theta_{n-1}^s)} w_{n-1}^s

So you can imagine doing this for all of the intermediate distributions until we reach the final one. The algorithm is rather straightforward from here: 1) sample from some initial proposal, 2) compute an importance weight, 3) evaluate the sample at the next density, which gives you a new importance weight, 4) move the sample around the new density with the transition kernel, then 5) repeat for the following density. A minimal sketch of the whole loop is shown below.
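Here is a minimal end-to-end sketch of that loop, under my own simplifying assumptions: a scalar \theta, a Gaussian prior used as the proposal (so the initial weights are exactly 1), a linear \beta schedule, and random-walk Metropolis moves as the transition kernels. The function name ais and all parameters are mine, not from the post.

```python
import numpy as np

def ais(log_likelihood, n_particles=1_000, n_steps=50, mh_moves=5,
        prior_mu=0.0, prior_sigma=5.0, step=0.5, seed=0):
    # Annealed Importance Sampling for a scalar theta with a Gaussian prior.
    # log_likelihood must be vectorized over an array of theta values.
    rng = np.random.default_rng(seed)
    betas = np.linspace(0.0, 1.0, n_steps + 1)   # smooth schedule, 0 -> 1

    def log_prior(th):
        return -0.5 * ((th - prior_mu) / prior_sigma) ** 2 \
               - np.log(prior_sigma * np.sqrt(2.0 * np.pi))

    def log_gamma(th, beta):
        # log gamma_n(theta) = beta_n * log p(y | theta) + log p(theta)
        return beta * log_likelihood(th) + log_prior(th)

    # Steps 1) and 2): sample from the proposal (the prior), so log w_0^s = 0.
    theta = rng.normal(prior_mu, prior_sigma, size=n_particles)
    log_w = np.zeros(n_particles)

    for b_prev, b_next in zip(betas[:-1], betas[1:]):
        # Step 3): re-weight the current samples under the next density.
        log_w += log_gamma(theta, b_next) - log_gamma(theta, b_prev)
        # Step 4): move the samples within the new density (random-walk MH).
        for _ in range(mh_moves):
            prop = theta + step * rng.normal(size=n_particles)
            log_alpha = log_gamma(prop, b_next) - log_gamma(theta, b_next)
            accept = np.log(rng.uniform(size=n_particles)) < log_alpha
            theta = np.where(accept, prop, theta)
        # Step 5): repeat with the following density.
    return theta, log_w
```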

Understanding Idea 1: Deriving New Importance Weights

So how do we end up with the equation for the new weight evaluated at the next density? To answer this, let’s begin with a valid importance weight. Let’s assume that we sampled from a normalized density, x \sim \pi_{n-1}(x), and we have an unnormalized density at the next step, \gamma_n(x). Then a reasonable importance weight would be:

w_n = \frac{\gamma_n(x)}{\pi_{n-1}(x)} \qquad x \sim \pi_{n-1}(x)

We know that the relationship between a normalized density and an unnormalized density is the normalization constant, Z_n, where \pi_n(x) = \frac{\gamma_n(x)}{Z_n}. Then one way we can rewrite w_n is:

w_n = \frac{\gamma_n(x)}{\pi_{n-1}(x)} = \frac{\gamma_n(x)}{\gamma_{n-1}(x)} Z_{n-1} \qquad x \sim \pi_{n-1}(x)

Now imagine instead of sampling from \pi_{n-1}(x), we sample from \pi_{n-2}(x). Then we’d have:

w_n = \frac{\gamma_n(x)}{\pi_{n-2}(x)} \qquad x \sim \pi_{n-2}(x)

Then if we use the multiplying-by-one trick, with 1 = \frac{\gamma_{n-1}(x)}{\gamma_{n-1}(x)}, we have:

w_n = \frac{\gamma_n(x)}{\pi_{n-2}(x)} =   \frac{\gamma_n(x)}{\gamma_{n-1}(x)} \frac{\gamma_{n-1}(x)}{\pi_{n-2}(x)} \qquad x \sim \pi_{n-2}(x)

Notice that \frac{\gamma_{n-1}(x)}{\pi_{n-2}(x)} = w_{n-1}! Therefore we can write:

w_n = \frac{\gamma_n(x)}{\gamma_{n-1}(x)} w_{n-1}  \qquad x \sim \pi_{n-2}(x)

This is really nice for us because now all we have to do to get a new weight is relate it to the previous one.
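In fact, this recursion is exactly where the marginal likelihood estimate from the beginning of the post comes from. If our proposal is the normalized prior (so w_0 = 1 and Z_0 = 1) and the final density is the joint, \gamma_N(\theta) = p(y, \theta), then, as Neal shows, averaging the final weights over samples estimates the final normalization constant:

\frac{1}{S} \sum_{s=1}^{S} w_N^s \approx \frac{Z_N}{Z_0} = \int p(y, \theta)\, d\theta = p(y)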

Understanding Idea 2: Transition Kernels

So now we talk about combining Importance Sampling with MCMC.

We start with an importance sampler, which means we sample from a proposal and then compute an importance weight. Now, MCMC lets us move the sample around using a transition kernel, \kappa(x' \mid x) (assuming it satisfies detailed balance).

w = \frac{\gamma(x)}{q(x)} \qquad x \sim q(x) \qquad x' \sim \kappa(x' \mid x)

We know that we update the importance weight relative to a new density, but do we need to update our weight after moving our sample around in the new density? Well, let’s think a little bit.

w' = \:?

Let’s remember what an importance weight looks like: w = \frac{\gamma(x)}{q(x)}. So getting some \gamma and q for our w' would mean considering both x and x'. Let’s expand on that starting with the proposal.

The proposal mechanism here is really doing two things at once, producing both x and x', so we write it as \tilde{q}(x, x'). This is done as shown before: sampling from some proposal and then applying a transition kernel.

\tilde{q}(x,x') = q(x)\kappa(x'\mid x)

The unnormalized density also considers both x and x', \tilde{\gamma}(x, x'), this time with the kernel run in reverse:

\tilde{\gamma}(x,x') = \gamma(x')\kappa(x\mid x')

So then our new weight would look like:

w' = \frac{\tilde{\gamma}(x,x')}{\tilde{q}(x,x')} = \frac{\gamma(x')\kappa(x\mid x')}{q(x)\kappa(x'\mid x)}

If our kernel satisfies detailed balance, \gamma(x)\kappa(x' \mid x) = \gamma(x')\kappa(x \mid x'), then solving the detailed balance equation for \gamma(x) gives \frac{\gamma(x')\kappa(x \mid x')}{\kappa(x' \mid x)} = \gamma(x). Our new weight now looks like:

w' = \frac{\gamma(x')\kappa(x\mid x')}{q(x)\kappa(x'\mid x)} = \frac{\gamma(x)}{q(x)} = w

which equals the original importance weight we calculated when we evaluated the sample on the new density!

So, back to the original question: do we need to update our weight after moving our sample around in the new density? The answer is nope! This is neat because we can run MCMC within each intermediate distribution while preserving the importance weights!
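As a hypothetical sanity check, we can run the ais sketch from earlier on a conjugate Gaussian model (y_i \sim N(\theta, 1) with \theta \sim N(0, 5^2)), where the marginal likelihood p(y) is available in closed form; everything below is my own illustrative setup.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
y = rng.normal(2.0, 1.0, size=20)                   # data from N(2, 1)

def log_likelihood(theta):
    # log p(y | theta) with y_i ~ N(theta, 1), vectorized over theta.
    th = np.atleast_1d(theta)
    return -0.5 * np.sum((y[None, :] - th[:, None]) ** 2, axis=1) \
           - 0.5 * len(y) * np.log(2.0 * np.pi)

theta, log_w = ais(log_likelihood)
log_py_ais = logsumexp(log_w) - np.log(len(log_w))  # log of the mean weight

# Exact marginal: y ~ N(0, I + prior_sigma^2 * 11^T) for this conjugate model.
cov = np.eye(len(y)) + 25.0
log_py_exact = multivariate_normal.logpdf(y, mean=np.zeros(len(y)), cov=cov)
print(log_py_ais, log_py_exact)                     # the two should roughly agree
```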
