ParallelTempering

class inference.mcmc.ParallelTempering(chains: list[inference.mcmc.base.MarkovChain])

A class which enables ‘parallel tempering’, a sampling algorithm which advances multiple Markov chains in parallel, each with a different ‘temperature’, while giving the chains a chance to exchange their positions as they advance.

The ‘temperature’ concept introduces a transformation to the distribution being sampled, such that a chain with temperature ‘T’ instead samples from the provided posterior distribution raised to the power 1/T.
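Written out, with π(θ) denoting the posterior density, the chain at temperature T targets:

$$
\pi_T(\theta) \propto \pi(\theta)^{1/T}
\quad\Longleftrightarrow\quad
\ln \pi_T(\theta) = \frac{1}{T}\,\ln \pi(\theta) + \text{const.}
$$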

When T = 1 the original distribution is recovered, but choosing T > 1 ‘compresses’ the distribution: the log-density is divided by T, so the difference in log-density between any two points shrinks as the temperature is increased and the distribution becomes flatter. This allows chains with higher temperatures to take much larger steps and explore the distribution more quickly.
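A small numerical illustration, using a hypothetical 1-D density with two peaks (not one from the library): because tempering divides the log-density by T, the density ratio between a peak and the valley between the peaks is raised to the power 1/T and approaches 1 as T grows.

```python
import numpy as np


def log_density(x):
    # hypothetical bimodal log-density: two unit-width Gaussian peaks at x = -4 and x = +4
    return np.logaddexp(-0.5 * (x - 4.0) ** 2, -0.5 * (x + 4.0) ** 2)


def tempered_log_density(x, T):
    # raising the density to the power 1/T divides the log-density by T
    return log_density(x) / T


peak, valley = 4.0, 0.0
for T in [1.0, 10.0, 100.0]:
    ratio = np.exp(tempered_log_density(peak, T) - tempered_log_density(valley, T))
    print(f"T = {T:>5}: peak/valley density ratio = {ratio:.2f}")
```

Here the peak is about 1490 times denser than the valley at T = 1, but only about 2.1 times at T = 10 and about 1.08 times at T = 100, which is why hot chains cross between the peaks easily.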

Parallel tempering exploits this by advancing a collection of Markov chains at different temperatures, with at least one chain at T = 1 (i.e. sampling from the actual posterior distribution). At regular intervals, pairs of chains are selected at random and a Metropolis-Hastings test is performed to decide whether each pair exchanges positions.
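For a pair of chains with inverse temperatures beta_i = 1/T_i and beta_j = 1/T_j, whose current positions have untempered log-posterior values L_i and L_j, the standard parallel-tempering swap test accepts the exchange with probability min(1, exp((beta_i - beta_j) * (L_j - L_i))). The sketch below implements that textbook rule; it is not taken from the library's source, and the function name swap_accepted is hypothetical.

```python
import math
import random


def swap_accepted(log_post_i, log_post_j, temp_i, temp_j, rng=random):
    """Metropolis-Hastings test for exchanging the positions of chains i and j.

    log_post_i, log_post_j: untempered log-posterior at each chain's current position.
    temp_i, temp_j: the temperatures of the two chains.
    rng: any object with a random() method returning a float in [0, 1).
    """
    beta_i, beta_j = 1.0 / temp_i, 1.0 / temp_j
    # log acceptance ratio for the joint tempered target prod_k p(x_k)**beta_k
    log_ratio = (beta_i - beta_j) * (log_post_j - log_post_i)
    if log_ratio >= 0.0:
        return True
    # log_ratio < 0 here, so exp(log_ratio) lies in (0, 1)
    return rng.random() < math.exp(log_ratio)
```

If the hotter chain holds the more probable position (T_i < T_j and L_j > L_i), the log ratio is positive and the swap is always accepted, so good positions found by hot chains are passed down toward T = 1.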

The ability of the T = 1 chain to exchange positions with higher-temperature chains allows it to make large jumps to other areas of the distribution, which it might otherwise take a large number of steps to reach.

This is particularly useful when sampling from highly complex distributions which may have many separate maxima and/or strong correlations.

Parameters

chains – A list of Markov chain objects (such as GibbsChain, PcaChain, HamiltonianChain) covering a range of different temperature levels. The list should be sorted in order of increasing chain temperature.

advance(n, swap_interval=10)

Advances each chain by a total of n steps, performing swap attempts at intervals set by the swap_interval keyword.

Parameters
  • n (int) – The number of steps each chain will advance.

  • swap_interval (int) – The number of steps that are taken in each chain between swap attempts.
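The behaviour described for advance can be illustrated with a single-process sketch: every chain takes n steps, and after every swap_interval steps one swap is attempted. It does not use the library; log_post, ToyChain and this advance function are hypothetical stand-ins (random-walk Metropolis on a 1-D bimodal target), and choosing a random pair of neighbouring temperature levels is an assumption, since the library's pair-selection rule is only described as random.

```python
import numpy as np


def log_post(x):
    # hypothetical 1-D bimodal log-posterior: narrow peaks (width 0.5) at x = -5 and x = +5
    return np.logaddexp(-0.5 * ((x - 5.0) / 0.5) ** 2, -0.5 * ((x + 5.0) / 0.5) ** 2)


def metropolis_accept(log_ratio, rng):
    # accept with probability min(1, exp(log_ratio)); min() avoids overflow in exp
    return rng.random() < np.exp(min(0.0, log_ratio))


class ToyChain:
    """Random-walk Metropolis chain targeting log_post(x) / temperature (hypothetical)."""

    def __init__(self, start, temperature, step, rng):
        self.x = start
        self.lp = log_post(start)  # untempered log-posterior at the current position
        self.temperature = temperature
        self.step = step
        self.rng = rng
        self.samples = []

    def take_step(self):
        proposal = self.x + self.step * self.rng.normal()
        lp_prop = log_post(proposal)
        # the tempered log-density is log_post / T, so divide the difference by T
        if metropolis_accept((lp_prop - self.lp) / self.temperature, self.rng):
            self.x, self.lp = proposal, lp_prop
        self.samples.append(self.x)


def advance(chains, n, swap_interval, rng):
    """Advance every chain by n steps, attempting one swap every swap_interval steps.

    chains must be sorted by increasing temperature.
    """
    for step in range(1, n + 1):
        for chain in chains:
            chain.take_step()
        if step % swap_interval == 0:
            # assumption: propose a swap between a random pair of neighbouring levels
            k = rng.integers(len(chains) - 1)
            a, b = chains[k], chains[k + 1]
            log_ratio = (1.0 / a.temperature - 1.0 / b.temperature) * (b.lp - a.lp)
            if metropolis_accept(log_ratio, rng):
                # exchange positions; each chain keeps its own temperature
                a.x, b.x = b.x, a.x
                a.lp, b.lp = b.lp, a.lp


rng = np.random.default_rng(1)
temperatures = [1.0, 3.0, 10.0, 30.0]
chains = [ToyChain(start=5.0, temperature=T, step=1.0, rng=rng) for T in temperatures]
advance(chains, n=40_000, swap_interval=10, rng=rng)

cold = np.array(chains[0].samples)
print("fraction of T = 1 samples in the x < 0 mode:", np.mean(cold < 0))
```

On its own, the T = 1 chain started at x = 5 would essentially never cross the low-density valley at x = 0; in this sketch it can only reach the x < 0 peak through swaps with the hotter chains.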

return_chains() → list[inference.mcmc.base.MarkovChain]

Recover the chain held by each process and return the chains in a list.

Returns

A list containing the chain objects.

run_for(minutes=0, hours=0, swap_interval=10)

Advances all chains for a chosen amount of computation time.

Parameters
  • minutes (float) – Number of minutes for which to advance the chains.

  • hours (float) – Number of hours for which to advance the chains.

  • swap_interval (int) – The number of steps that are taken in each chain between swap attempts.
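run_for advances the chains in repeated blocks until a wall-clock budget is used up. The sketch below shows one way such a time-budgeted loop can work. It assumes that the minutes and hours arguments are added together into a single budget and that the clock is only checked between blocks; neither detail is confirmed by this page. run_for here is a hypothetical stand-alone function taking a step_fn callback, not the method itself.

```python
import time


def run_for(step_fn, minutes=0.0, hours=0.0, clock=time.monotonic):
    """Call step_fn() repeatedly until the combined time budget has elapsed.

    step_fn is a hypothetical stand-in for 'advance every chain by
    swap_interval steps, then attempt a swap'. Returns the number of calls made.
    """
    budget = 60.0 * minutes + 3600.0 * hours  # assumption: the two arguments are summed
    deadline = clock() + budget
    calls = 0
    while clock() < deadline:
        step_fn()
        calls += 1
    return calls
```

Because the clock is only checked between calls, the total run time can exceed the budget by up to one call of step_fn.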

shutdown()

Trigger a shutdown event which tells the processes holding each of the chains to terminate.

ParallelTempering example code

Define a posterior with separated maxima, which is difficult for a single chain to explore:

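from numpy import log, sqrt, sin, arctan2, pi

# define a posterior with multiple separate peaks
def multimodal_posterior(theta):
    x, y = theta
    # polar coordinates of the point
    r = sqrt(x**2 + y**2)
    phi = arctan2(y, x)
    # distance from the spiral-shaped ridge r = 0.5 + pi - phi/2, in units of its width 0.1
    z = (r - (0.5 + pi - phi*0.5)) / 0.1
    # the log(sin(2*phi)**2) term is largest where sin(2*phi)**2 = 1, splitting the ridge into separate peaks
    return -0.5*z**2 + 4*log(sin(phi*2.)**2)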

Define a set of temperature levels:

N_levels = 6
temperatures = [10**(2.5*k/(N_levels-1.)) for k in range(N_levels)]
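This list comprehension produces a geometric ladder running from T = 1 up to T = 10**2.5 (about 316), so each temperature is a constant factor of 10**0.5 (about 3.16) above the previous one. As a sketch, the same ladder can be built with numpy.geomspace, checked here against the original expression:

```python
import numpy as np

N_levels = 6
# original construction: 10**(2.5*k/(N_levels-1)) for k = 0 .. N_levels-1
temperatures = [10**(2.5*k/(N_levels-1.)) for k in range(N_levels)]
# the same geometric ladder, spanning 1 to 10**2.5 inclusive
ladder = np.geomspace(1.0, 10**2.5, N_levels)

print(np.allclose(temperatures, ladder))  # True
print(ladder[1:] / ladder[:-1])           # every entry is 10**0.5
```

Either form starts exactly at T = 1, so the first chain samples the actual posterior.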

Create a set of chains - one with each temperature:

from inference.mcmc import GibbsChain, ParallelTempering
chains = [
    GibbsChain(posterior=multimodal_posterior, start=[0.5, 0.5], temperature=T)
    for T in temperatures
]

When an instance of ParallelTempering is created, a dedicated process is spawned for each chain. The operating system can run these separate processes on the available CPU cores, so the computations which advance the separate chains are performed in parallel.

PT = ParallelTempering(chains=chains)

These processes wait for instructions which can be sent using the methods of the ParallelTempering object:

PT.run_for(minutes=0.5)

To recover a copy of the chains held by the processes, we can use the return_chains method:

chains = PT.return_chains()

By looking at the trace plot for the T = 1 chain, we see that it makes large jumps across the parameter space due to the swaps:

chains[0].trace_plot()
[Figure: trace plot of the T = 1 chain (parallel_tempering_trace.png)]

The matrix plot shows that, even though the posterior has strongly separated peaks, the T = 1 chain was able to explore all of them thanks to the swaps:

chains[0].matrix_plot()
[Figure: matrix plot of the T = 1 chain samples (parallel_tempering_matrix.png)]

Because each process waits for instructions from the ParallelTempering object, the processes will not self-terminate. To terminate all of them, we have to trigger a shutdown event using the shutdown method:

PT.shutdown()
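Since the processes only terminate when shutdown() is called, an exception raised during sampling could leave them running. One way to guarantee the call is a small context manager built with the standard library's contextlib. shutdown_on_exit is a hypothetical helper, not part of inference.mcmc, and it relies only on the shutdown method documented above.

```python
from contextlib import contextmanager


@contextmanager
def shutdown_on_exit(pt):
    """Yield pt, then call pt.shutdown() even if the body raises an exception."""
    try:
        yield pt
    finally:
        pt.shutdown()
```

With this helper, the example could open with `with shutdown_on_exit(ParallelTempering(chains=chains)) as PT:` followed by the indented `PT.run_for(...)` and `PT.return_chains()` calls, and the processes are shut down when the block exits, whether normally or through an error.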