{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Adaptive rejection sampling\n", "\n", "Rejection sampling (RS) is a useful method for sampling from intractable distributions. It defines an envelope function which upper-bounds the unnormalised probability density of the target. It then samples points uniformly from the area under the envelope, rejecting those points which fall above the target density and accepting the rest. The accepted points are independent and identically distributed samples from the target distribution. There are two important issues with RS. The first is that if the envelope is a very loose upper bound, then most samples will be rejected and the scheme will be slow. The second is that for RS to work, we must be certain that the envelope is an upper bound to the target, which can be challenging to guarantee in practice.\n", "\n", "Adaptive rejection sampling (ARS) {cite}`gilks1992ars` is an efficient method for sampling log-concave targets, which deals with both of these issues. It was originally defined for univariate distributions, but can also be extended to multivariate distributions via Gibbs sampling {cite}`bishop2006PRML`. ARS maintains an envelope which adapts as more points are sampled, becoming a progressively tighter bound to the target, thereby avoiding the inefficiency of regular RS. Further, the way that ARS constructs this envelope guarantees that the envelope is in fact an upper bound to the target, which sidesteps the second difficulty described above."
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "from IPython.display import HTML\n", "from matplotlib_inline.backend_inline import set_matplotlib_formats\n", "set_matplotlib_formats('pdf', 'svg')\n", "#css_style = open('../../../_static/custom_style.css', 'r').read()\n", "#HTML(f'')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## An adaptive envelope function\n", "\n", "Suppose we wish to sample from a log-concave univariate distribution with unnormalised density $f$. Whereas RS defines a fixed envelope, ARS defines an envelope $g_u$ that upper-bounds $f$ and adapts its shape as the sampling procedure progresses. By adapting its shape, using the information that $f$ is log-concave, the envelope reduces the probability of future rejections. In addition to the envelope function, ARS can use an optional function $g_l$ which lower-bounds $f$, called the squeezing function. The squeezing function can be used to avoid evaluating $f$ in the rejection step, which is especially useful if $f$ is computationally expensive to evaluate.\n", "\n", "Given an ordered set of points $x_1 < x_2 < ... < x_K$, ARS defines the log-envelope $\\log g_u$ to be the minimum over the tangents to $h = \\log f$ at these points. The log-squeezing function is defined to be the piecewise linear function which joins the points $(x_k, h(x_k))$ inside the interval $[x_1, x_K]$ and is equal to $-\\infty$ outside this interval. Examples of envelope and squeezing functions are shown below.\n", "\n", "
\n", " \n", "**Definition (Abscissa set, envelope function and squeezing function)** Let $f(x)$ be a univariate log-concave function, with non-zero domain $D = \\{x : f(x) > 0\\}$. An abscissa set $T_K$ is an ordered set of points in $D$ such that\n", " \n", "$$T_K = \\{x_1 < x_2 < ... < x_K\\}.$$\n", " \n", "The envelope function $g_u(x)$ defined by $T_K$ is\n", " \n", "$$g_u(x) = \\min_{k} g_{u, k}(x)$$\n", " \n", "where $g_{u, k}(x), k = 1, 2, ..., K$ are exponential functions, that is exponentials of linear functions of $x$, such that\n", " \n", "$$g_{u, k}(x_k) = f(x_k) \\text{ and } g_{u, k}'(x_k) = f'(x_k).$$\n", " \n", "The squeezing function $g_l(x)$ defined by $T_K$ is\n", " \n", "$$g_l(x) = \\begin{cases} \\min_{k} g_{l, k}(x) & \\text{ if } x_1 \\leq x \\leq x_K, \\\\ 0 & \\text{ otherwise,} \\end{cases}$$\n", " \n", "where $g_{l, k}(x), k = 1, 2, ..., K - 1$ are exponential functions such that\n", " \n", "$$g_{l, k}(x_k) = f(x_k) \\text{ and } g_{l, k}(x_{k+1}) = f(x_{k+1}).$$\n", " \n", "
\n", "
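As a sanity check on the definition above, here is a minimal sketch which builds the log-envelope and log-squeezing function for a standard Gaussian log-density and verifies on a grid that they sandwich it. The helper names `log_upper` and `log_lower`, the grid and the choice of abscissa points are illustrative assumptions and are not part of the implementation used on this page.

```python
import numpy as np

# Illustrative abscissa set for h(x) = -x^2 / 2, a standard Gaussian log-density
xs = np.array([-1.0, 0.1, 1.5])
hs = -0.5 * xs ** 2
dhdxs = -xs


def log_upper(x):
    # Log-envelope: minimum over the tangent lines to h at the abscissa points
    return np.min(hs + dhdxs * (x - xs))


def log_lower(x):
    # Log-squeezing function: chords inside [x_1, x_K], minus infinity outside
    if x < xs[0] or x > xs[-1]:
        return -np.inf
    return np.interp(x, xs, hs)


grid = np.linspace(-3.0, 3.0, 601)
h_grid = -0.5 * grid ** 2
upper = np.array([log_upper(x) for x in grid])
lower = np.array([log_lower(x) for x in grid])

# Small tolerance to absorb floating point rounding at the abscissa points
print(np.all(lower <= h_grid + 1e-12), np.all(h_grid <= upper + 1e-12))
```

Both checks should hold because $h$ is concave: every tangent lies above $h$ and every chord lies below it between its end-points.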
\n", "\n", "Below are functions implementing the calculations needed to determine the envelope and squeezing functions, all of which are cheap operations. The functions take input locations $x$ together with the corresponding $h$ and $h'$ values, and carry out the computations in log-space before exponentiating the result at the end." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def g_u(x, xs, hs, dhdxs):\n", "\n", "    # Intersection points z of neighbouring tangents delimit the envelope pieces\n", "    z, _ = compute_points_of_intersection_and_intercepts(xs, hs, dhdxs)\n", "    i = np.searchsorted(z, x)\n", "\n", "    # Exponentiate the tangent at xs[i], which is the active piece at x\n", "    return np.exp(dhdxs[i] * (x - xs[i]) + hs[i])\n", "\n", "\n", "def g_l(x, xs, hs):\n", "\n", "    # The squeezing function is zero outside [xs[0], xs[-1]]\n", "    if all(x < xs) or all(x > xs):\n", "        return 0.\n", "\n", "    else:\n", "        # Exponentiate the chord joining the two abscissa points around x\n", "        i = np.searchsorted(xs, x)\n", "        m = (hs[i] - hs[i-1]) / (xs[i] - xs[i-1])\n", "\n", "        return np.exp(hs[i-1] + (x - xs[i-1]) * m)\n", "\n", "\n", "def compute_points_of_intersection_and_intercepts(x, h, dhdx):\n", "\n", "    # y-intercepts c of envelope function line segments, intersection points z\n", "    c = h - dhdx * x\n", "    z = (c[1:] - c[:-1]) / (dhdx[:-1] - dhdx[1:])\n", "\n", "    return z, c" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's define an unnormalised Gaussian log-density, and use it to illustrate the envelope and squeezing functions defined by an abscissa set with three points."
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def log_gaussian(mean, variance):\n", "    # Returns a function mapping x to the pair (h(x), h'(x))\n", "    return lambda x : (- 0.5 * (x - mean) ** 2 / variance, - (x - mean) / variance)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "tags": [ "hide-input", "center-output" ] }, "outputs": [ { "data": { "text/plain": [ "[Figure: log f = h (black), log g_u (red) and log g_l (green) plotted against x for the three-point abscissa set]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# The log unnormalised density to illustrate\n", "log_prob = log_gaussian(0., 1.)\n", "\n", "# Points in the abscissa set and corresponding log-probabilities and gradients\n", "xs = np.array([-1., 0.1, 1.5])\n", "hs, dhdxs = log_prob(xs)\n", "\n", "# Locations to plot the log unnorm. density and envelope/squeezing functions\n", "x_plot = np.linspace(-2, 2, 200)\n", "log_probs = [log_prob(x)[0] for x in x_plot]\n", "gu = [g_u(x, xs, hs, dhdxs) for x in x_plot]\n", "gl = [g_l(x, xs, hs) for x in x_plot]\n", "\n", "# Plot the log unnormalised density, the envelope and squeezing functions\n", "plt.figure(figsize=(6, 3))\n", "plt.scatter(xs, hs, color='k', zorder=3)\n", "plt.plot(x_plot, log_probs, color='black', label=r'$\\log~f = h$')\n", "plt.plot(x_plot, np.log(gu), color='red', label=r'$\\log~g_u$')\n", "\n", "# Handle the case of negatively infinite log gl, for plotting presentation\n", "floored_log_gl = np.log(np.maximum(np.array(gl), np.ones_like(gl) * 1e-9))\n", "plt.plot(x_plot, floored_log_gl, color='green', label=r'$\\log~g_l$')\n", "\n", "# Plot formatting\n", "plt.xlim([-2, 2])\n", "plt.ylim([-3, 1])\n", "plt.xticks([])\n", "plt.yticks([])\n", "plt.xlabel('$x$', fontsize=18)\n", "plt.ylabel(r'$\\log~f(x)$', fontsize=18)\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Adaptive rejection sampling\n", "\n", "As observed above, by virtue of the log-concavity of $f$, the envelope and squeezing functions defined in this way are upper and lower bounds to $f$. If we sample a point uniformly at random from the area under $g_u$, and this point also happens to be in the area under $f$, then the point is uniformly distributed in the area under $f$, and its $x$-coordinate is an exact sample from the target distribution.
Further, if the point happened to lie in the area under the squeezing function $g_l$, it is certain to also lie in the area under $f$, so we need not check this latter condition explicitly. This shortcut is particularly useful if the function $f$ is expensive to evaluate, because it avoids some of the evaluations of $f$. Combining these checks, we arrive at the ARS algorithm below.\n", "\n", "
\n", " \n", "**Algorithm (Adaptive Rejection Sampling)** Given a univariate log-concave unnormalised probability density $f(x)$, perform the following initialisation, sampling and update steps:\n", " \n", "1. Initialise an abscissa set $T_K$, such that $f'(x_1) > 0$ and $f'(x_K) < 0$, as well as the corresponding envelope and squeezing functions $g_u$ and $g_l$. This can be efficiently achieved by starting from an initial guess and stepping out in steps of exponentially increasing size.\n", "2. Sample \\\\[x' \\sim \\frac{g_u(x')}{\\int g_u(x) dx} \\text{ and } z \\sim \\text{Uniform}(0, 1),\\\\] and perform the following squeezing and rejection tests. If \\\\[z \\leq \\frac{g_l(x')}{g_u(x')}\\\\] holds, then accept $x'$. Otherwise perform the rejection test \\\\[z \\leq \\frac{f(x')}{g_u(x')}.\\\\] If this holds, accept the point and otherwise reject it.\n", "3. If $x'$ was accepted at the squeezing test, go to step 2 immediately. Otherwise insert $x'$ into $T_K$ to obtain $T_{K+1}$, update the piecewise exponential functions $g_l$ and $g_u$ accordingly and then return to step 2.\n", " \n", "
\n", "
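The two tests in step 2 of the algorithm above can be sketched for a single proposal as follows. This is only an illustration under stated assumptions: the function name `squeeze_then_reject` and the toy bounds are hypothetical, chosen so that the arithmetic is easy to follow, and the toy envelope is a constant which could not itself be sampled from.

```python
import math


def squeeze_then_reject(x, z, g_l, g_u, f):
    # Returns (accepted, evaluated_f) for a proposal x and a uniform draw z
    upper = g_u(x)

    # Squeezing test: accept without evaluating f
    if z * upper <= g_l(x):
        return True, False

    # Rejection test: requires one evaluation of f
    return z * upper <= f(x), True


# Toy target f(x) = exp(-x^2 / 2) with hypothetical bounds:
# the envelope is the exponentiated tangent at 0, the squeezing
# function is the exponentiated chord between -1 and 1
f = lambda x: math.exp(-0.5 * x ** 2)
g_u = lambda x: 1.0
g_l = lambda x: math.exp(-0.5) if -1.0 <= x <= 1.0 else 0.0

print(squeeze_then_reject(0.5, 0.3, g_l, g_u, f))  # (True, False)
print(squeeze_then_reject(0.5, 0.8, g_l, g_u, f))  # (True, True)
print(squeeze_then_reject(2.0, 0.5, g_l, g_u, f))  # (False, True)
```

The second flag records whether $f$ had to be evaluated. Following step 3, this is exactly the case in which the proposal would be inserted into the abscissa set, whether or not it was accepted.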
\n", "\n", "Below are functions which implement envelope sampling, that is drawing\n", "\n", "$$ x' \\sim \\frac{g_u(x')}{\\int g_u(x) dx}. $$\n", "\n", "The first function determines the left and right limits of each exponential piece of the envelope, as well as its unnormalised probability, that is the integral of $g_{u, k}$ between these limits. The second samples uniformly from the area under the envelope function $g_u$." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "def envelope_limits_and_unnormalised_probabilities(xs, hs, dhdxs):\n", "\n", "    # Compute the points of intersection of the lines making up the envelope\n", "    z, c = compute_points_of_intersection_and_intercepts(xs, hs, dhdxs)\n", "\n", "    # Left-right endpoints for each piece in the piecewise envelope\n", "    limits = np.concatenate([[float('-inf')], z, [float('inf')]])\n", "    limits = np.stack([limits[:-1], limits[1:]], axis=-1)\n", "\n", "    probs = (np.exp(dhdxs * limits[:, 1]) - np.exp(dhdxs * limits[:, 0])) * np.exp(c)\n", "\n", "    # Catch any intervals where dhdx was zero\n", "    idx_nonzero = np.where(dhdxs != 0.)\n", "    probs[idx_nonzero] = probs[idx_nonzero] / dhdxs[idx_nonzero]\n", "\n", "    idx_zero = np.where(dhdxs == 0.)\n", "    probs[idx_zero] = ((limits[:, 1] - limits[:, 0]) * np.exp(c))[idx_zero]\n", "\n", "    return limits, probs\n", "\n", "\n", "def sample_envelope(xs, hs, dhdxs):\n", "\n", "    limits, probs = envelope_limits_and_unnormalised_probabilities(xs, hs, dhdxs)\n", "    probs = probs / np.sum(probs)\n", "\n", "    # Randomly chosen interval in which the sample lies\n", "    i = np.random.choice(np.arange(probs.shape[0]), p=probs)\n", "\n", "    # Sample u ~ Uniform(0, 1)\n", "    u = np.random.uniform()\n", "\n", "    # Invert i^th piecewise exponential CDF to get a sample from that interval\n", "    if dhdxs[i] == 0.:\n", "        return u * (limits[i, 1] - limits[i, 0]) + limits[i, 0]\n", "\n", "    else:\n", "        x = np.log(\n", "            u * np.exp(dhdxs[i] * limits[i, 1])\n", "            + (1 - u) * np.exp(dhdxs[i] * limits[i, 0])\n", "        )\n", "        x = x / dhdxs[i]\n", "\n", "        return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we draw samples from the envelope defined by the three previous points without the rejection step, we obtain the following distribution." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "tags": [ "hide-input", "center-output" ] }, "outputs": [ { "data": { "text/plain": [ "[Figure: histogram of 10000 envelope samples (gray) with the normalised envelope g_u (red), plotted against x]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x_plot = np.linspace(-4., 4., 200)\n", "_, probs = envelope_limits_and_unnormalised_probabilities(xs, hs, dhdxs)\n", "\n", "samples = [sample_envelope(xs, hs, dhdxs) for i in range(10000)]\n", "gu = [g_u(x, xs, hs, dhdxs) / np.sum(probs) for x in x_plot]\n", "\n", "# Plot samples and envelope\n", "plt.figure(figsize=(6, 3))\n", "plt.plot(x_plot,\n", "         gu,\n", "         color='red',\n", "         label='Normalised $g_u$')\n", "\n", "plt.hist(samples,\n", "         density=True,\n", "         bins=100,\n", "         color='gray',\n", "         alpha=0.5,\n", "         label='Envelope samples')\n", "\n", "# Plot formatting\n", "plt.title('', fontsize=20)\n", "plt.xlim([-4, 4])\n", "plt.ylim([0, 0.5])\n", "plt.xticks([])\n", "plt.yticks([])\n", "plt.xlabel('$x$', fontsize=18)\n", "plt.ylabel('$g_u(x)~/~Z$', fontsize=18)\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We still need to add the initialisation of the abscissa set, the (optional) squeezing test and the rejection test. For the initialisation step, we can start from an initial point, and search to the left and to the right in exponentially increasing step sizes, until we find a point on the left side with positive $h'$ and a point on the right with negative $h'$, and use these as end-points of the abscissa set. The following function `initialise_abcissa` implements this initialisation step." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def initialise_abcissa(x0, log_unnorm_prob):\n", "\n", "    # Expand to the left/right until the abscissa set is correctly initialised\n", "    xs = np.array([x0])\n", "    hs, dhdxs = log_unnorm_prob(xs)\n", "\n", "    # Start by stepping to the left\n", "    dx = -1.\n", "\n", "    while True:\n", "\n", "        # Left end-point has positive gradient: switch to stepping right\n", "        if dx < 0. and dhdxs[0] > 0.:\n", "            dx = 1.\n", "\n", "        # Right end-point has negative gradient: initialisation is complete\n", "        elif dx > 0. and dhdxs[-1] < 0.:\n", "            break\n", "\n", "        insert_idx = 0 if dx < 0 else len(xs)\n", "\n", "        x = xs[0 if dx < 0 else -1] + dx\n", "\n", "        h, dhdx = log_unnorm_prob(x)\n", "\n", "        xs = np.insert(xs, insert_idx, x)\n", "        hs = np.insert(hs, insert_idx, h)\n", "        dhdxs = np.insert(dhdxs, insert_idx, dhdx)\n", "\n", "        # Double the step size for the next iteration\n", "        dx = dx * 2\n", "\n", "    return xs, hs, dhdxs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This approach to initialising the abscissa set does not have any tunable parameters, except `x0`. Any initialisation method which guarantees $h'(x_1) > 0$ and $h'(x_K) < 0$ will give a valid abscissa set and this method is only a specific choice. Changing the value of `x0` does not significantly affect the efficiency of ARS, since the initialisation method terminates quickly because of the exponentially increasing step sizes. This implementation assumes that the domain $D$ of $f$, that is the set of points where $f$ is non-zero, is all of $\\mathbb{R}$. If this is not the case, then this initialisation function will fail. A more robust initialisation method could use boolean comparisons of $h'$ values instead of $h$ values, setting $h = -\\infty$ outside $D$, but for the purposes of exposition, this illustration assumes that $D = \\mathbb{R}$ and does not bother further with this technical point. Putting the initialisation step together with the squeezing and rejection steps, we arrive at the complete ARS algorithm below."
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "def adaptive_rejection_sampling(x0, log_unnorm_prob, num_samples):\n", "\n", "    xs, hs, dhdxs = initialise_abcissa(x0=x0, log_unnorm_prob=log_unnorm_prob)\n", "\n", "    samples = []\n", "\n", "    while len(samples) < num_samples:\n", "\n", "        x = sample_envelope(xs, hs, dhdxs)\n", "\n", "        gl = g_l(x, xs, hs)\n", "        gu = g_u(x, xs, hs, dhdxs)\n", "\n", "        # Squeezing test\n", "        u = np.random.rand()\n", "\n", "        if u * gu <= gl:\n", "            # Accept without evaluating the target and draw the next proposal,\n", "            # so that the sample is not appended a second time below\n", "            samples.append(x)\n", "            continue\n", "\n", "        h, dhdx = log_unnorm_prob(x)\n", "\n", "        # Rejection test\n", "        if u * gu <= np.exp(h):\n", "            samples.append(x)\n", "\n", "        # The target was evaluated at x, so add x to the abscissa set\n", "        i = np.searchsorted(xs, x)\n", "\n", "        xs = np.insert(xs, i, x)\n", "        hs = np.insert(hs, i, h)\n", "        dhdxs = np.insert(dhdxs, i, dhdx)\n", "\n", "    return samples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can finally use this function to sample from the example standard Gaussian distribution."
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "np.random.seed(0)\n", "\n", "target_mean = 0.\n", "target_variance = 1.\n", "\n", "x0 = 1.\n", "num_samples = 10000\n", "\n", "log_unnorm_prob = log_gaussian(mean=target_mean, variance=target_variance)\n", "\n", "samples = adaptive_rejection_sampling(x0=x0, log_unnorm_prob=log_unnorm_prob, num_samples=num_samples)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "tags": [ "hide-input", "center-output" ] }, "outputs": [ { "data": { "text/plain": [ "[Figure: histogram of 10000 ARS samples (gray) with the normalised target f(x) (black), plotted against x]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Normalised probabilities for plotting the target\n", "x_plot = np.linspace(-4, 4, 200)\n", "target_probs = [np.exp(log_prob(x)[0]) / (2 * np.pi) ** 0.5 for x in x_plot]\n", "\n", "# Plot samples and target\n", "plt.figure(figsize=(6, 3))\n", "\n", "plt.hist(samples,\n", "         density=True,\n", "         bins=50,\n", "         color='gray',\n", "         alpha=0.5,\n", "         label='Samples')\n", "plt.plot(x_plot,\n", "         target_probs,\n", "         color='black',\n", "         label='Normalised $f(x)$')\n", "\n", "# Plot formatting\n", "plt.xlim([-4, 4])\n", "plt.ylim([0, 0.5])\n", "plt.xticks([])\n", "plt.yticks([])\n", "plt.xlabel('$x$', fontsize=18)\n", "plt.ylabel('$f(x)~/~Z$', fontsize=18)\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusions\n", "\n", "ARS is an efficient method for sampling log-concave univariate distributions. Although very effective for differentiable log-concave one-dimensional distributions, the algorithm presented here has two shortcomings. First, it requires gradients of the log-density with respect to the input variable. These may be expensive to compute, or may not exist everywhere if $f$ is not differentiable. For this, there exists a modified version of ARS {cite}`gilks1992adaptive` which builds the envelope in a way that does not require gradients. This page presented the gradient-based method because it is simpler to illustrate. Second, although many distributions of practical interest are log-concave, there are many others which are not. In this case, the ARS algorithm does not apply since the envelope is no longer guaranteed to upper-bound the target density. For this, there exists an extension of ARS for non-log-concave distributions called adaptive rejection Metropolis sampling {cite}`gilks1995adaptive` (ARMS).
ARMS also builds an envelope and uses it to propose samples, which are then accepted or rejected using a Metropolis-Hastings step to ensure that the samples are distributed according to the target. Note that in general, ARMS does not produce independent samples from the target, due to the Metropolis-Hastings accept/reject step. For log-concave functions, ARMS reduces to ARS." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n", "\n", "```{bibliography} ./ars.bib\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 4 }