Gradient Estimators for Discrete Random Variables

Renjie Liao · July 2, 2020

Machine Learning   Statistics

In this post, I will review several popular gradient estimators for discrete random variables. In machine learning, especially in latent variable models and reinforcement learning (RL), we often face the following situation. We have a discrete random variable \(Z\) which takes values from \(K\) categories, where \(K\) could be finite or countably infinite. Assuming the distribution associated with \(Z\) is \(q_{\phi}(Z)\), we would like to optimize the following expected objective,

\[\begin{align}\label{eq:objective} \mathcal{L}(\phi) = \mathbb{E}_{q_{\phi}(Z)} \left[ f(Z) \right]. \end{align}\]

For example, in RL, \(Z\) is the discrete action, \(q_{\phi}(Z)\) is the policy, and \(f(Z)\) is the reward. To optimize the objective using powerful gradient-based methods, we need to compute the gradient. In particular, if we can work out the expectation in Eq. (\(\ref{eq:objective}\)) exactly, then the gradient is simple,

\[\begin{align}\label{eq:gradient_exact} \nabla \mathcal{L}(\phi) = \sum_{k=1}^{K} \frac{\partial q_{\phi}(Z=k)}{\partial \phi} f(Z=k). \end{align}\]

However, life is hard as always. In reality, this exact sum is often intractable because \(K\) is very large or even countably infinite. Therefore, people have developed various tricks to deal with this problem.
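To make the setup concrete, here is a minimal NumPy sketch that computes the exact gradient in Eq. (\(\ref{eq:gradient_exact}\)) by enumerating all \(K\) categories; the softmax parameterization of \(q_{\phi}\) and the random toy \(f\) are my own choices for illustration, not part of the setup above. The estimators below can be checked against this ground truth when \(K\) is small.

```python
import numpy as np

def softmax(phi):
    # q_phi(Z = k) = exp(phi_k) / sum_j exp(phi_j)
    e = np.exp(phi - phi.max())
    return e / e.sum()

def exact_gradient(phi, f_values):
    # d/d phi of sum_k q_phi(k) f(k), using the softmax Jacobian
    # d q_k / d phi_j = q_k (1[j = k] - q_j).
    q = softmax(phi)
    jac = np.diag(q) - np.outer(q, q)   # jac[k, j] = d q_k / d phi_j
    return jac.T @ f_values             # gradient w.r.t. phi, shape (K,)

K = 5
rng = np.random.default_rng(0)
phi = rng.normal(size=K)                # parameters (logits) of q_phi
f_values = rng.normal(size=K)           # f(Z = k) for each category
print(exact_gradient(phi, f_values))
```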

Score Function Estimator

Basic Estimator

The most straightforward (and also quite elegant) solution is the score function estimator (or the REINFORCE algorithm, if your background is more machine learning than statistics). It works as follows,

\[\begin{align}\label{eq:gradient_reinforce} \nabla \mathcal{L}(\phi) & = \sum_{k=1}^{K} \left. \frac{\partial q_{\phi}(Z)}{\partial \phi} f(Z) \right\rvert_{Z=k} \nonumber \\ & = \sum_{k=1}^{K} \left. q_{\phi}(Z) \frac{\partial \log q_{\phi}(Z)}{\partial \phi} f(Z) \right\rvert_{Z=k} \nonumber \\ & = \mathbb{E}_{q_{\phi}(Z)} \left[ \frac{\partial \log q_{\phi}(Z)}{\partial \phi} f(Z) \right] \nonumber \\ & \approx \frac{1}{M} \sum_{m=1}^{M} \frac{\partial \log q_{\phi}(z_m)}{\partial \phi} f(z_m) \qquad z_m \sim q_{\phi}(Z). \end{align}\]

Therefore, the score function estimator (an estimator in statistics is defined w.r.t. random variables, not realized samples) is

\[\begin{align}\label{eq:reinforce_estimator} g_{\text{score}}(Z_{1:M}) = \frac{1}{M} \sum_{m=1}^{M} \frac{\partial \log q_{\phi}(Z_m)}{\partial \phi} f(Z_m), \qquad Z_m \sim q_{\phi}(Z). \end{align}\]

The origin of its name is that \(\frac{\partial \log q_{\phi}(Z)}{\partial \phi}\) is called the score function in the statistics literature. You might think this is a dumb trick, as it merely leverages the derivative of the logarithm. But it is actually quite brilliant, since we can now deal with a large (or even infinite) \(K\) using a tractably small \(M\). As with many other Monte Carlo estimators, the larger \(M\) is, the more accurate \(g_{\text{score}}\) is. Moreover, the estimator is unbiased, i.e.,

\[\begin{align} \mathbb{E} \left[ g_{\text{score}}(Z) \right] = \nabla \mathcal{L}(\phi). \end{align}\]

However, this estimator has a notorious drawback: its variance is often large in practice, which makes the learning/optimization process converge very slowly. Nevertheless, people often try it first due to its simplicity in both math and implementation.
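As a concrete illustration (same assumed softmax parameterization and toy \(f\) as in the earlier sketch), here is a minimal NumPy implementation of the score function estimator; for a softmax-parameterized categorical, the score \(\frac{\partial \log q_{\phi}(z)}{\partial \phi}\) is simply the one-hot vector of \(z\) minus the probability vector.

```python
import numpy as np

def softmax(phi):
    e = np.exp(phi - phi.max())
    return e / e.sum()

def score_function_estimator(phi, f, M, rng):
    # Monte Carlo estimate of grad_phi E_{q_phi}[f(Z)] with M samples.
    q = softmax(phi)
    K = len(q)
    z = rng.choice(K, size=M, p=q)      # z_m ~ q_phi(Z)
    grad = np.zeros(K)
    for zm in z:
        score = -q.copy()
        score[zm] += 1.0                # d log q_phi(z_m) / d phi = onehot(z_m) - q
        grad += score * f(zm)
    return grad / M

rng = np.random.default_rng(0)
phi = rng.normal(size=5)
f = lambda z: float(z ** 2)             # a toy "reward"
print(score_function_estimator(phi, f, M=10000, rng=rng))
```

With a large \(M\), the output should be close to the exact gradient from the first sketch; with a small \(M\), individual runs can be quite noisy.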

Score Function Estimator with Variance Reduction

To reduce the variance of the estimator, people have proposed various approaches, e.g., see this paper. One of the simplest is to add a control variate (also called a baseline in the context of RL). The idea is simple,

\[\begin{align}\label{eq:gradient_reinforce_baseline} \nabla \mathcal{L}(\phi) & = \mathbb{E}_{q_{\phi}(Z)} \left[ \frac{\partial \log q_{\phi}(Z)}{\partial \phi} f(Z) \right] \nonumber \\ & = \mathbb{E}_{q_{\phi}(Z)} \left[ \frac{\partial \log q_{\phi}(Z)}{\partial \phi} \left( f(Z) - C \right) \right] \nonumber \\ & \approx \frac{1}{M} \sum_{m=1}^{M} \frac{\partial \log q_{\phi}(z_m)}{\partial \phi} \left( f(z_m) - C \right) \qquad z_m \sim q_{\phi}(Z). \end{align}\]

In the second line, we use the fact that

\[\begin{align}\label{eq:baseline_trick} \mathbb{E}_{q_{\phi}(Z)} \left[ \frac{\partial \log q_{\phi}(Z)}{\partial \phi} \right] & = \sum_{k=1}^{K} q_{\phi}(Z=k) \frac{1}{q_{\phi}(Z=k)} \frac{\partial q_{\phi}(Z=k)}{\partial \phi} \nonumber \\ & = \sum_{k=1}^{K} \frac{\partial q_{\phi}(Z=k)}{\partial \phi} \nonumber \\ & = \frac{\partial \sum_{k=1}^{K} q_{\phi}(Z=k)}{\partial \phi} \nonumber \\ & = \frac{\partial 1}{\partial \phi} = 0. \end{align}\]

Therefore, the new control variate estimator is,

\[\begin{align}\label{eq:control_variate_estimator} g_{\text{score_cv}}(Z_{1:M}) = \frac{1}{M} \sum_{m=1}^{M} \frac{\partial \log q_{\phi}(Z_m)}{\partial \phi} \left( f(Z_m) - C \right). \end{align}\]

The idea is to choose a control variate \(C\) (as long as \(C\) does not depend on \(Z\)) so that the variance of the estimator is reduced. Some popular choices in practice are the empirical mean of the function \(f\) or even a learnable baseline (e.g., in the context of RL). You might wonder what the optimal estimator is in terms of variance reduction. This general topic is one of the main problems in mathematical statistics, known as the minimum-variance unbiased estimator (MVUE) or uniformly minimum-variance unbiased estimator (UMVUE). I will cover this topic in the future.
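Continuing the sketch above (same assumed softmax parameterization and toy \(f\)), the control variate version only changes one line: we subtract \(C\) from \(f(z_m)\). The baseline used here, an empirical mean of \(f\) over a separate batch, is just one common heuristic for illustration.

```python
import numpy as np

def softmax(phi):
    e = np.exp(phi - phi.max())
    return e / e.sum()

def score_cv_estimator(phi, f, M, C, rng):
    # Score function estimator with a constant control variate (baseline) C.
    q = softmax(phi)
    K = len(q)
    z = rng.choice(K, size=M, p=q)      # z_m ~ q_phi(Z)
    grad = np.zeros(K)
    for zm in z:
        score = -q.copy()
        score[zm] += 1.0                # d log q_phi(z_m) / d phi
        grad += score * (f(zm) - C)     # subtracting a constant C does not introduce bias
    return grad / M

rng = np.random.default_rng(0)
phi = rng.normal(size=5)
f = lambda z: float(z ** 2)
# Baseline: empirical mean of f over a separate batch of samples.
q = softmax(phi)
C = np.mean([f(z) for z in rng.choice(5, size=100, p=q)])
print(score_cv_estimator(phi, f, M=10000, C=C, rng=rng))
```

Any constant \(C\) keeps the estimator unbiased; a good choice of \(C\) only shrinks the variance.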

Rao-Blackwellization

Once you have an unbiased estimator, a typical strategy for reducing its variance is Rao-Blackwellization. The idea goes as follows in our context. Given an unbiased estimator, e.g., either \(g_{\text{score}}\) or \(g_{\text{score_cv}}\) (take \(M = 1\) for simplicity in what follows), we construct its Rao-Blackwell estimator as follows,

\[\begin{align} g_{\text{score_rb}} (Z) = \mathbb{E} \left[ g_{\text{score_cv}} (Z) \vert T(Z) \right], \end{align}\]

where \(T(Z)\) is a sufficient statistic for the parameter to be estimated, i.e., \(\nabla \mathcal{L}(\phi)\). In other words, \(q_{\phi}(Z \vert T(Z))\) does not depend on \(\nabla \mathcal{L}(\phi)\). Intuitively, a sufficient statistic captures all the information in \(Z\) that is relevant to \(\nabla \mathcal{L}(\phi)\); a statistic just means a function of the data. The conditional expectation is taken w.r.t. the distribution of \(g_{\text{score_cv}}(Z)\), but we can equivalently write it as an expectation w.r.t. \(Z\) itself using the so-called law of the unconscious statistician (LOTUS).

The Rao-Blackwell Theorem then guarantees that the variance of the new estimator is no larger (and typically smaller) than that of the original one, while the unbiasedness is preserved. This process of transforming an estimator via the Rao-Blackwell Theorem is called Rao-Blackwellization. I will prove the Rao-Blackwell Theorem in another post.
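In particular, the unbiasedness part follows directly from the tower property of conditional expectation,

\[\begin{align} \mathbb{E} \left[ g_{\text{score_rb}} (Z) \right] = \mathbb{E} \left[ \mathbb{E} \left[ g_{\text{score_cv}} (Z) \vert T(Z) \right] \right] = \mathbb{E} \left[ g_{\text{score_cv}} (Z) \right] = \nabla \mathcal{L}(\phi). \end{align}\]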

In particular, in our context, we show one specific construction of the “Rao-Blackwell estimator” following this work. We first split the support of the distribution \(q_{\phi}(Z)\) into a set \(\mathcal{C}_k\) and its complement \(\bar{\mathcal{C}}_k\), where \(\mathcal{C}_k\) contains the categories with the top \(k\) probabilities. Denoting \(\epsilon = \sum_{z \in \bar{\mathcal{C}}_k} q_{\phi}(Z = z)\), we can construct the following two distributions,

\[\begin{align}\label{eq:aux_distribution} q_{\phi \vert \mathcal{C}_k}(U) & = \frac{1}{1 - \epsilon} q_{\phi}(U) & \qquad U \in \mathcal{C}_k \nonumber \\ q_{\phi \vert \bar{\mathcal{C}}_k}(V) & = \frac{1}{\epsilon} q_{\phi}(V) & \qquad V \in \bar{\mathcal{C}}_k. \end{align}\]

Now we have an equivalent representation \(Z = U^{1-B} V^{B}\), where \(B\) is a Bernoulli random variable with \(p_{\epsilon}(B) = \text{Bernoulli}(\epsilon)\), i.e., \(Z = U\) if \(B = 0\) and \(Z = V\) if \(B = 1\). Hence, the “Rao-Blackwell estimator” is,

\[\begin{align}\label{eq:rao_blackwell_estimator} g_{\text{score_rb}} (Z) & = \mathbb{E}_{Z} \left[ g_{\text{score_cv}}(Z) \vert V \right] \nonumber \\ & = \mathbb{E}_{U, B} \left[ g_{\text{score_cv}}(U^{1-B}V^{B}) \vert V \right] \nonumber \\ & = (1 - \epsilon) \mathbb{E}_{U} \left[ g_{\text{score_cv}}(U) \vert V \right] + \epsilon \mathbb{E}_{U} \left[ g_{\text{score_cv}}(V) \vert V \right] \nonumber \\ & = (1 - \epsilon) \sum_{u \in \mathcal{C}_k} g_{\text{score_cv}}(u) \frac{q_{\phi}(u)}{1 - \epsilon} + \epsilon g_{\text{score_cv}}(V) \nonumber \\ & = \sum_{u \in \mathcal{C}_k} g_{\text{score_cv}}(u) q_{\phi}(u) + \epsilon g_{\text{score_cv}}(V). \end{align}\]

As you can see, (1) the first term is deterministic while the second is stochastic, and (2) the estimator requires \(k + 1\) evaluations of the original estimator. Note that one can replace \(g_{\text{score_cv}}(V)\) with any other unbiased estimator. Technically speaking, this is not a Rao-Blackwell estimator, since \(V\) is not a sufficient statistic. It is easy to show that the sufficient statistic of the categorical distribution (an instance of the exponential family) \(q_{\phi}(Z)\) is \(\left[\bm{1}[Z = 1], \bm{1}[Z = 2], \dots, \bm{1}[Z = K] \right]\). However, the above estimator still enjoys the benefit of Rao-Blackwellization due to the law of total variance,

\[\begin{align}\label{eq:variance_reduction} \mathbb{V} \left( g_{\text{score_cv}} (Z) \right) & = \mathbb{V} \left( g_{\text{score_rb}} (Z) \right) + \mathbb{E} \left[ \mathbb{V} \left( g_{\text{score_cv}}(Z) \vert V \right) \right], \end{align}\]

where the second term is clearly nonnegative, so that \(\mathbb{V} \left( g_{\text{score_rb}} (Z) \right) \le \mathbb{V} \left( g_{\text{score_cv}} (Z) \right)\). More specifically, the authors show that

\[\begin{align}\label{eq:variance_ineq} \mathbb{V} \left( g_{\text{score_cv}} (Z) \right) & = \mathbb{E} \left[ \mathbb{V} \left( g_{\text{score_cv}}(Z) \vert B \right) \right] + \mathbb{V} \left( \mathbb{E} \left[ g_{\text{score_cv}}(Z) \vert B \right] \right) \nonumber \\ & \ge \mathbb{E} \left[ \mathbb{V} \left( g_{\text{score_cv}}(Z) \vert B \right) \right] \nonumber \\ & = \epsilon \mathbb{V} \left( g_{\text{score_cv}}(Z) \vert B = 1 \right) + (1 - \epsilon) \mathbb{V} \left( g_{\text{score_cv}}(Z) \vert B = 0 \right) \nonumber \\ & \ge \epsilon \mathbb{V} \left( g_{\text{score_cv}}(Z) \vert B = 1 \right) \nonumber \\ & = \epsilon \mathbb{V} \left( g_{\text{score_cv}}(Z) \vert Z = V \right) \nonumber \\ & = \epsilon \mathbb{V} \left( g_{\text{score_cv}}(V) \right) \nonumber \\ & = \frac{1}{\epsilon} \mathbb{V} \left( g_{\text{score_rb}}(V) \right) \end{align}\]

where the 1st line uses the law of total variance again, the 5th line uses the fact that \(Z = V\) is equivalent to \(B = 1\), and the last line uses Eq. (\(\ref{eq:rao_blackwell_estimator}\)) along with the facts that \(\mathbb{V}(a + X) = \mathbb{V}(X)\) and \(\mathbb{V}(aX) = a^2\mathbb{V}(X)\). Since \(\epsilon\) is strictly less than \(1\), we are guaranteed to reduce the variance by using \(g_{\text{score_rb}}(V)\). Moreover, the authors show that the above estimator is no worse than the simple minibatching estimator \(\frac{1}{M} \sum_{m=1}^{M} g_{\text{score_cv}} (Z_m)\).
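Putting the pieces together, here is a minimal NumPy sketch of this construction (my own illustration, not the authors' code; it assumes the same softmax-parameterized \(q_{\phi}\) and toy \(f\) as before, and uses the single-sample \(g_{\text{score_cv}}\) as the inner estimator): the top-\(k\) categories are summed exactly, and a single sample \(V\) from the complement contributes the stochastic term weighted by \(\epsilon\).

```python
import numpy as np

def softmax(phi):
    e = np.exp(phi - phi.max())
    return e / e.sum()

def score_cv(q, z, f, C):
    # The inner single-sample estimator g_score_cv evaluated at category z.
    score = -q.copy()
    score[z] += 1.0                         # d log q_phi(z) / d phi
    return score * (f(z) - C)

def rao_blackwell_estimator(phi, f, k, C, rng):
    q = softmax(phi)
    order = np.argsort(q)[::-1]             # categories sorted by probability
    Ck, Ck_bar = order[:k], order[k:]       # top-k set and its complement
    eps = q[Ck_bar].sum()                   # epsilon = probability mass of the complement
    # Deterministic part: exact sum over the top-k categories.
    grad = sum(q[u] * score_cv(q, u, f, C) for u in Ck)
    # Stochastic part: one sample V from q restricted to the complement.
    v = rng.choice(Ck_bar, p=q[Ck_bar] / eps)
    return grad + eps * score_cv(q, v, f, C)

rng = np.random.default_rng(0)
phi = rng.normal(size=8)
f = lambda z: float(z ** 2)
print(rao_blackwell_estimator(phi, f, k=3, C=0.0, rng=rng))
```

Averaged over many draws, this matches the exact gradient (unbiasedness), and its variance is reduced as shown in Eq. (\(\ref{eq:variance_ineq}\)).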

Sampling Without Replacement

TBD

Relaxation Based Estimator

Gumbel-Softmax

TBD

Combination of Both

REBAR & RELAX

TBD