FFVB: Control variate
This section describes a control variate technique for variance reduction.
See: Variational Bayes Introduction, Fixed Form Variational Bayes
Control variate
$\def\t{\theta} \def\LB{\text{LB}} \def\E{\mathbb{E}} \def\KL{\text{KL}} \newcommand{\wh}{\widehat} \newcommand{\wt}{\widetilde} \def\F{\cal{F}} \def\N{\cal{N}} \def\s{\sigma} \def\a{\alpha} \def\b{\beta} \def\l{\lambda} \def\d{d} \newcommand{\eps}{\epsilon} \def\veps{\varepsilon} \def\vech{\text{vech}} \def\diag{\text{diag}} \def\V{\mathbb{V}} \def\cov{\text{cov}}$ As is typical of stochastic optimization algorithms, the performance of Algorithm 3 depends greatly on the variance of the noisy gradient. Variance reduction for the noisy gradient is a key ingredient in FFVB algorithms.
Let $\t_s\sim q_\l(\t)$, $s=1,…,S$, be $S$ samples from the variational distribution $q_{\l}(\t)$. A naive estimator of the $i$th element of the vector $\nabla_\lambda\text{LB}(\l)$ is \[\tag{20}\label{eq: naive KL estimate}\wh{\nabla_{\lambda_i}\LB}(\l)^{\text{naive}}=\frac1S\sum_{s=1}^S\nabla_{\lambda_i}[\log q_\lambda(\t_s)]\times h_\lambda(\t_s),\] whose variance is often too large to be useful.
For any number $c_i$, consider
\[\tag{21}\label{eq: reduced var KL estimate} \wh{\nabla_{\lambda_i}\LB}(\l)=\frac1S\sum_{s=1}^S\nabla_{\lambda_i}[\log q_\lambda(\t_s)](h_\lambda(\t_s)-c_i),\]which remains an unbiased estimator of $\nabla_{\lambda_i}\text{LB}(\l)$ because $\E(\nabla_{\lambda}[\log q_\lambda(\t)])=0$; its variance can be greatly reduced by an appropriate choice of the control variate $c_i$.
The variance of $\wh{\nabla_{\lambda_i}\LB}(\l)$ is
\[\frac1S\V\Big(\nabla_{\lambda_i}[\log q_\lambda(\t)]h_\lambda(\t)\Big)+\frac{c_i^2}{S}\V\Big(\nabla_{\lambda_i}[\log q_\lambda(\t)]\Big)-\frac{2c_i}{S}\cov\Big(\nabla_{\lambda_i}[\log q_\lambda(\t)]h_\lambda(\t),\nabla_{\lambda_i}[\log q_\lambda(\t)]\Big).\]The optimal $c_i$ that minimizes this variance is
\[\tag{22}\label{eq:optimal c_i} c_i=\cov\Big(\nabla_{\lambda_i}[\log q_\lambda(\t)] h_\lambda(\t),\nabla_{\lambda_i}[\log q_\lambda(\t)]\Big)\Big/\V\Big(\nabla_{\lambda_i}[\log q_\lambda(\t)]\Big).\]Then
\[\V\left(\wh{\nabla_{\lambda_i}\LB}(\l)\right)=\V\left(\wh{\nabla_{\lambda_i}\LB}(\l)^{\text{naive}}\right)(1-\rho^2_i) \leq \V\left(\wh{\nabla_{\lambda_i}\LB}(\l)^{\text{naive}}\right),\]where $\rho_i$ is the correlation between $\nabla_{\lambda_i}[\log q_\lambda(\t)]h_\lambda(\t)$ and $\nabla_{\lambda_i}[\log q_\lambda(\t)]$. Often, $\rho_i^2$ is very close to 1, which leads to a large variance reduction.
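To make this concrete, the snippet below compares the naive estimator $\eqref{eq: naive KL estimate}$ with the control-variate estimator $\eqref{eq: reduced var KL estimate}$, using the optimal $c$ of $\eqref{eq:optimal c_i}$, for a single variational parameter. It is only a minimal sketch: the conjugate Gaussian model, the fixed value of $\sigma$, the sample sizes and all variable names are illustrative assumptions, not part of the method itself.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy conjugate model, used purely for illustration: y_i ~ N(theta, 1) with a
# N(0, 1) prior on theta; the variational family is q_lambda = N(mu, sigma^2)
# with sigma held fixed, so lambda = mu and only dLB/dmu is estimated.
y = rng.normal(1.0, 1.0, size=100)
mu, sigma = 0.0, 0.1
S, n_rep = 10, 2000            # Monte Carlo samples per estimate / repetitions

def h(theta, mu):
    """h_lambda(theta) = log p(theta) + log p(y|theta) - log q_lambda(theta), up to constants."""
    log_prior = -0.5 * theta**2
    log_lik = -0.5 * ((y[None, :] - theta[:, None])**2).sum(axis=1)
    log_q = -0.5 * ((theta - mu) / sigma)**2
    return log_prior + log_lik - log_q

def score(theta, mu):
    """d log q_lambda(theta) / d mu for q_lambda = N(mu, sigma^2)."""
    return (theta - mu) / sigma**2

# Estimate the control variate c (eq. 22) from samples that are kept independent
# of the samples later used for the gradient, as required for unbiasedness.
theta_c = rng.normal(mu, sigma, size=1000)
f_c, s_c = score(theta_c, mu) * h(theta_c, mu), score(theta_c, mu)
c = np.cov(f_c, s_c)[0, 1] / np.var(s_c, ddof=1)

naive, cv = [], []
for _ in range(n_rep):
    theta = rng.normal(mu, sigma, size=S)
    naive.append(np.mean(score(theta, mu) * h(theta, mu)))       # naive estimator, eq. (20)
    cv.append(np.mean(score(theta, mu) * (h(theta, mu) - c)))    # control-variate estimator, eq. (21)

post_prec, post_mean = 1 + len(y), y.sum() / (1 + len(y))        # exact Gaussian posterior
print("exact gradient  : %.2f" % (-post_prec * (mu - post_mean)))
print("naive           : mean %9.2f   std %9.2f" % (np.mean(naive), np.std(naive)))
print("control variate : mean %9.2f   std %9.2f" % (np.mean(cv), np.std(cv)))
```

Both estimators agree with the exact gradient in expectation, but the spread of the naive estimator across repetitions is typically far larger; this is exactly the effect quantified by $1-\rho_i^2$ above.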
One can estimate the numbers $c_i$ in $\eqref{eq:optimal c_i}$ using samples $\t_s\sim q_{\l}(\t)$. In order to ensure the unbiasedness of the gradient estimator, the samples used to estimate $c_i$ must be independent of the samples used to estimate the gradient.
In practice, the $c_i$ can be updated sequentially as follows. At iteration $t$, we use the $c_i$ computed in the previous iteration $t-1$, i.e. based on the samples from $q_{\l^{(t-1)}}(\t)$, to estimate the gradient $\wh{\nabla_\l\text{LB}}(\l^{(t)})$, which is computed using new samples from $q_{\l^{(t)}}(\t)$.
We then update the $c_i$ using this new set of samples. This guarantees unbiasedness, and no extra samples are needed to update the control variates $c_i$.
Algorithm 4 provides a detailed pseudo-code implementation of the FFVB approach that uses the control variate for variance reduction and moving average adaptive learning, and Algorithm 5 implements the FFVB approach that uses the control variate and natural gradient.
Algorithm 4: FFVB with control variates and adaptive learning
- Input: Initial $\l^{(0)}$, adaptive learning weights $\beta_1,\beta_2\in(0,1)$, fixed learning rate $\eps_0$, threshold $\tau$, rolling window size $t_W$ and maximum patience $P$. Model-specific requirement: function $h(\theta):=\log\big(p(\theta)p(y\mid\theta)\big)$.
- Initialization
  - Generate $\theta_s\sim q_{\lambda^{(0)}}(\theta)$, $s=1,…,S$.
  - Compute the unbiased estimate of the LB gradient \[\wh{\nabla_\l\text{LB}}(\l^{(0)}):=\frac{1}{S}\sum_{s=1}^S\nabla_\lambda \log q_\lambda(\theta_s)\times h_\lambda(\theta_s)|_{\lambda=\lambda^{(0)}}.\]
  - Set $g_0:=\wh{\nabla_\l\text{LB}}(\l^{(0)})$, $v_0:=(g_0)^2$, $\bar g:=g_0$, $\bar v:=v_0$.
  - Estimate the vector of control variates $c$ as in \eqref{eq:optimal c_i} using the samples $\{\theta_s,s=1,…,S\}$.
  - Set $t=0$, $\text{patience}=0$ and $\texttt{stop=false}$.
- While $\texttt{stop=false}$:
  - Generate $\theta_s\sim q_{\lambda^{(t)}}(\theta)$, $s=1,…,S$.
  - Compute the unbiased estimate of the LB gradient \[g_t:=\wh{\nabla_\l\text{LB}}(\l^{(t)})=\frac{1}{S}\sum_{s=1}^S\nabla_\lambda \log q_\lambda(\theta_s)\circ \big(h_\lambda(\theta_s)-c\big)|_{\lambda=\lambda^{(t)}}.\]
  - Estimate the new control variate vector $c$ as in \eqref{eq:optimal c_i} using the samples $\{\theta_s,s=1,…,S\}$.
  - Compute $v_t=(g_t)^2$ and \[\bar g =\beta_1 \bar g+(1-\beta_1)g_t,\;\;\bar v =\beta_2 \bar v+(1-\beta_2)v_t.\]
  - Compute $\alpha_t=\min(\epsilon_0,\epsilon_0\frac{\tau}{t})$ and update \[\l^{(t+1)}=\l^{(t)}+\a_t \bar g/\sqrt{\bar v}.\]
  - Compute the lower bound estimate \[\wh{\text{LB}}(\l^{(t)}):=\frac{1}{S}\sum_{s=1}^S h_{\lambda^{(t)}}(\theta_s).\]
  - If $t\geq t_W$: compute the moving average lower bound \[\overline {\text{LB}}_{t-t_W+1}=\frac{1}{t_W}\sum_{k=1}^{t_W} \wh{\text{LB}}(\l^{(t-k+1)}),\] and if $\overline {\text{LB}}_{t-t_W+1}\geq\max(\overline\LB)$, set $\text{patience}:=0$; else set $\text{patience}:=\text{patience}+1$.
  - If $\text{patience}\geq P$, set $\texttt{stop=true}$.
  - Set $t:=t+1$.
Note: The term $\nabla_\lambda \log q_\lambda(\theta_s)\circ \big(h_\lambda(\theta_s)-c\big)$ should be understood component-wise, i.e. it is the vector whose $i$th element is $\nabla_{\lambda_i} \log q_\lambda(\theta_s)\times \big(h_\lambda(\theta_s)-c_i\big)$.
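To illustrate how the steps above fit together, here is a compact, self-contained Python sketch of Algorithm 4 for a toy conjugate model ($y_i\sim\N(\theta,1)$ with a $\N(0,1)$ prior) and a Gaussian variational family $q_\lambda=\N(\mu,\sigma^2)$ with $\lambda=(\mu,\log\sigma)$. The model, the hyperparameter values and the helper-function names are illustrative assumptions; the exact Gaussian posterior is only used as a check at the end.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy conjugate model (an illustrative assumption, not part of Algorithm 4 itself):
# y_i ~ N(theta, 1), theta ~ N(0, 1), so the exact posterior is Gaussian.
y = rng.normal(1.0, 1.0, size=50)

# Variational family q_lambda = N(mu, sigma^2) with lambda = (mu, log sigma).
def unpack(lam):
    return lam[0], np.exp(lam[1])

def sample_q(lam, S):
    mu, sigma = unpack(lam)
    return rng.normal(mu, sigma, size=S)

def h(theta, lam):
    """h_lambda(theta) = log p(theta) + log p(y|theta) - log q_lambda(theta)."""
    mu, sigma = unpack(lam)
    log_prior = -0.5 * theta**2
    log_lik = -0.5 * ((y[None, :] - theta[:, None])**2).sum(axis=1)
    log_q = -0.5 * ((theta - mu) / sigma)**2 - np.log(sigma)
    return log_prior + log_lik - log_q

def score(theta, lam):
    """Rows are grad_lambda log q_lambda(theta_s), for lambda = (mu, log sigma)."""
    mu, sigma = unpack(lam)
    z = (theta - mu) / sigma
    return np.column_stack([z / sigma, z**2 - 1.0])

def control_variates(theta, lam):
    """Component-wise optimal c_i of eq. (22), estimated from the given samples."""
    s = score(theta, lam)
    f = s * h(theta, lam)[:, None]
    return np.array([np.cov(f[:, i], s[:, i])[0, 1] / np.var(s[:, i], ddof=1)
                     for i in range(s.shape[1])])

# Algorithm 4: FFVB with control variates and adaptive learning
# (the hyperparameter values below are arbitrary illustrative choices).
S, eps0, tau, beta1, beta2, t_W, P = 50, 0.05, 1000, 0.9, 0.9, 50, 20
lam = np.array([0.0, 0.0])                        # lambda^(0): mu = 0, sigma = 1

# Initialization: naive gradient (no control variates yet), moving averages, c.
theta = sample_q(lam, S)
g = (score(theta, lam) * h(theta, lam)[:, None]).mean(axis=0)
g_bar, v_bar = g, g**2
c = control_variates(theta, lam)

t, patience, lb_hist, best_lb = 0, 0, [], -np.inf
while patience < P:
    theta = sample_q(lam, S)
    h_val = h(theta, lam)
    # gradient uses the control variates estimated at the previous iteration
    g = (score(theta, lam) * (h_val[:, None] - c)).mean(axis=0)
    c = control_variates(theta, lam)              # re-estimate c from the new samples
    g_bar = beta1 * g_bar + (1 - beta1) * g
    v_bar = beta2 * v_bar + (1 - beta2) * g**2
    lb_hist.append(h_val.mean())                  # noisy estimate of LB(lambda^(t))
    alpha = eps0 if t < tau else eps0 * tau / t   # = min(eps0, eps0 * tau / t)
    lam = lam + alpha * g_bar / np.sqrt(v_bar)
    if t >= t_W:                                  # moving-average LB and patience rule
        lb_bar = np.mean(lb_hist[-t_W:])
        if lb_bar >= best_lb:
            best_lb, patience = lb_bar, 0
        else:
            patience += 1
    t += 1

mu_hat, sigma_hat = unpack(lam)
post_prec = 1 + len(y)
print("VB approximation : mu = %.3f, sigma = %.3f" % (mu_hat, sigma_hat))
print("Exact posterior  : mu = %.3f, sigma = %.3f" % (y.sum() / post_prec, post_prec ** -0.5))
```

Note how the control variates used for the gradient at iteration $t$ are always those estimated from the previous iteration's samples, exactly as described before Algorithm 4.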
Algorithm 5: FFVB with control variates and natural gradient
- Input: Initial $\l^{(0)}$, momentum weight $\alpha_m$, fixed learning rate $\eps_0$, threshold $\tau$, rolling window size $t_W$ and maximum patience $P$. Model-specific requirement: function $h(\theta):=\log\big(p(\theta)p(y\mid\theta)\big)$.
- Initialization
  - Generate $\theta_s\sim q_{\lambda^{(0)}}(\theta)$, $s=1,…,S$.
  - Compute the unbiased estimate of the LB gradient \[\wh{\nabla_\l\text{LB}}(\l^{(0)}):=\frac{1}{S}\sum_{s=1}^S\nabla_\lambda \log q_\lambda(\theta_s)\times h_\lambda(\theta_s)|_{\lambda=\lambda^{(0)}}\] and the natural gradient \[\wh{\nabla_{\lambda}\LB} (\l^{(0)})^{\text{nat}} := I_F^{-1}(\l^{(0)})\wh{\nabla_\l\LB}(\l^{(0)}).\]
  - Set the momentum gradient $\overline{{\nabla_\l{\LB}}}:=\wh{\nabla_{\lambda}\LB}(\l^{(0)})^{\text{nat}}$.
  - Estimate control variate vector $c$ as in \eqref{eq:optimal c_i} using the samples $\{\theta_s,s=1,…,S\}$.
  - Set $t=0$, $\text{patience}=0$ and $\texttt{stop=false}$.
- While $\texttt{stop=false}$:
  - Generate $\theta_s\sim q_{\lambda^{(t)}}(\theta)$, $s=1,…,S$.
  - Compute the unbiased estimate of the LB gradient \[\wh{\nabla_\l\text{LB}}(\l^{(t)})=\frac{1}{S}\sum_{s=1}^S\nabla_\lambda \log q_\lambda(\theta_s)\circ \big(h_\lambda(\theta_s)-c\big )|_{\lambda=\lambda^{(t)}}\] and the natural gradient \[\wh{\nabla_{\lambda}\LB} (\l^{(t)})^{\text{nat}} = I_F^{-1}(\l^{(t)})\wh{\nabla_\l\LB}(\l^{(t)}).\]
  - Estimate the new control variate vector $c$ as in \eqref{eq:optimal c_i} using the samples $\{\theta_s,s=1,…,S\}$.
  - Compute the momentum gradient \[\overline{{\nabla_\l{\LB}}} = \alpha_\text{m} \overline{{\nabla_\l{\LB}}}+(1-\alpha_\text{m})\wh{\nabla_{\lambda}\LB}(\l^{(t)})^{\text{nat}}.\]
  - Compute $\alpha_t=\min(\epsilon_0,\epsilon_0\frac{\tau}{t})$ and update \[\l^{(t+1)}=\l^{(t)}+\a_t \overline{{\nabla_\l{\LB}}}.\]
  - Compute the lower bound estimate \[\wh{\text{LB}}(\l^{(t)}):=\frac{1}{S}\sum_{s=1}^S h_{\lambda^{(t)}}(\theta_s).\]
  - If $t\geq t_W$: compute the moving average lower bound \[\overline {\text{LB}}_{t-t_W+1}=\frac{1}{t_W}\sum_{k=1}^{t_W} \wh{\text{LB}}(\l^{(t-k+1)}),\] and if $\overline {\text{LB}}_{t-t_W+1}\geq\max(\overline\LB)$, set $\text{patience}:=0$; else set $\text{patience}:=\text{patience}+1$.
  - If $\text{patience}\geq P$, set $\texttt{stop=true}$.
  - Set $t:=t+1$.
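Algorithm 5 differs from Algorithm 4 mainly in how the step direction is formed: the control-variate gradient is premultiplied by $I_F^{-1}(\l)$ and smoothed with momentum instead of being scaled by $\sqrt{\bar v}$. The sketch below shows that step for the same illustrative family $q_\lambda=\N(\mu,\sigma^2)$ with $\lambda=(\mu,\log\sigma)$, for which the Fisher information is available in closed form, $I_F(\l)=\diag(1/\sigma^2,\,2)$; the function and variable names are assumptions made for illustration only.

```python
import numpy as np

def natural_gradient_step(lam, g, g_mom, alpha, alpha_m=0.9):
    """One Algorithm 5 update for the illustrative family q_lambda = N(mu, sigma^2),
    lambda = (mu, log sigma). For this family I_F(lambda) = diag(1/sigma^2, 2), so
    its inverse is explicit. g is the control-variate gradient estimate of the
    current iteration and g_mom is the momentum gradient."""
    sigma = np.exp(lam[1])
    inv_fisher = np.diag([sigma**2, 0.5])              # I_F(lambda)^{-1}
    g_nat = inv_fisher @ g                              # natural gradient
    g_mom = alpha_m * g_mom + (1 - alpha_m) * g_nat     # momentum of natural gradients
    return lam + alpha * g_mom, g_mom

# Example usage with dummy values; in Algorithm 5, g would be the control-variate
# gradient estimate and g_mom would be initialized to the natural gradient at lambda^(0).
lam, g_mom = np.array([0.0, 0.0]), np.zeros(2)
g = np.array([1.0, -0.5])
lam, g_mom = natural_gradient_step(lam, g, g_mom, alpha=0.05)
print(lam, g_mom)
```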
Next: FFVB with reparameterization trick