
FFVB: Reparameterization Trick

This section describes the reparameterization trick, a technique for reducing the variance of the lower-bound gradient estimate.

See: Variational Bayes Introduction, Fixed Form Variational Bayes


Reparameterization Trick

The reparameterization trick is an attractive alternative to the control variate method. Suppose that, for $\theta\sim q_\lambda(\cdot)$, there exists a deterministic function $g(\lambda,\varepsilon)$ such that $\theta=g(\lambda,\varepsilon)\sim q_\lambda(\cdot)$ where $\varepsilon\sim p_\varepsilon(\cdot)$.

We emphasize that $p_\varepsilon(\cdot)$ must not depend on $\lambda$. For example, if $q_\lambda(\theta)=\mathcal{N}(\theta;\mu,\sigma^2)$, then $\theta=\mu+\sigma\varepsilon$ with $\varepsilon\sim\mathcal{N}(0,1)$.
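As a quick illustration of this Gaussian case, the following Python sketch (with purely illustrative parameter values) draws $\theta\sim q_\lambda(\cdot)$ by pushing $\varepsilon\sim\mathcal{N}(0,1)$ through the transformation $g(\lambda,\varepsilon)=\mu+\sigma\varepsilon$:

```python
import numpy as np

rng = np.random.default_rng(0)

mu, sigma = 1.5, 0.8                 # variational parameters lambda = (mu, sigma); values are illustrative

def g(mu, sigma, eps):
    """Deterministic transformation theta = g(lambda, eps) = mu + sigma * eps."""
    return mu + sigma * eps

eps = rng.standard_normal(10_000)    # eps ~ p_eps = N(0, 1), which does not depend on lambda
theta = g(mu, sigma, eps)            # theta ~ q_lambda = N(mu, sigma^2)

print(theta.mean(), theta.std())     # close to mu = 1.5 and sigma = 0.8
```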

Writing $\text{LB}(\lambda)$ as an expectation with respect to $p_\varepsilon(\cdot)$,

$$\text{LB}(\lambda)=\mathbb{E}_{\varepsilon\sim p_\varepsilon}\big(h_\lambda(g(\lambda,\varepsilon))\big),$$

where $\mathbb{E}_{\varepsilon\sim p_\varepsilon}(\cdot)$ denotes the expectation with respect to $p_\varepsilon(\cdot)$, and differentiating under the integral sign gives

$$\nabla_\lambda\text{LB}(\lambda)=\mathbb{E}_{\varepsilon\sim p_\varepsilon}\big(\nabla_\lambda g(\lambda,\varepsilon)^\top\nabla_\theta h_\lambda(\theta)\big)+\mathbb{E}_{\varepsilon\sim p_\varepsilon}\big(\nabla_\lambda h_\lambda(\theta)\big),$$

where the $\theta$ within $h_\lambda(\theta)$ is understood as $\theta=g(\lambda,\varepsilon)$ with $\lambda$ fixed.

In particular, the gradient $\nabla_\lambda h_\lambda(\theta)$ is taken with $\theta$ not regarded as a function of $\lambda$. Here, with some abuse of notation, $\nabla_\lambda g(\lambda,\varepsilon)$ denotes the Jacobian matrix, of size $d_\theta\times d_\lambda$, of the vector-valued function $\theta=g(\lambda,\varepsilon)$. Note that

$$\mathbb{E}_{\varepsilon\sim p_\varepsilon}\big(\nabla_\lambda h_\lambda(\theta)\big)=\mathbb{E}_{\varepsilon\sim p_\varepsilon}\big(\nabla_\lambda h_\lambda(\theta=g(\lambda,\varepsilon))\big)=-\mathbb{E}_{\varepsilon\sim p_\varepsilon}\big(\nabla_\lambda\log q_\lambda(\theta=g(\lambda,\varepsilon))\big)=-\mathbb{E}_{\theta\sim q_\lambda}\big(\nabla_\lambda\log q_\lambda(\theta)\big)=0,$$

where the second equality uses $h_\lambda(\theta)=h(\theta)-\log q_\lambda(\theta)$ and the final expectation vanishes because $\mathbb{E}_{\theta\sim q_\lambda}\big(\nabla_\lambda\log q_\lambda(\theta)\big)=\int\nabla_\lambda q_\lambda(\theta)\,\mathrm{d}\theta=\nabla_\lambda\int q_\lambda(\theta)\,\mathrm{d}\theta=0$,

hence

$$\nabla_\lambda\text{LB}(\lambda)=\mathbb{E}_{\varepsilon\sim p_\varepsilon}\big(\nabla_\lambda g(\lambda,\varepsilon)^\top\nabla_\theta h_\lambda(\theta)\big).\tag{24}$$
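For instance, for the Gaussian example above, taking the variational parameter to be $\lambda=(\mu,\sigma)$ with $g(\lambda,\varepsilon)=\mu+\sigma\varepsilon$, the Jacobian of $g$ with respect to $\lambda$ is $(1,\varepsilon)$, so (24) reduces to

$$\nabla_\mu\text{LB}(\lambda)=\mathbb{E}_{\varepsilon\sim p_\varepsilon}\big(\nabla_\theta h_\lambda(\mu+\sigma\varepsilon)\big),\qquad\nabla_\sigma\text{LB}(\lambda)=\mathbb{E}_{\varepsilon\sim p_\varepsilon}\big(\varepsilon\,\nabla_\theta h_\lambda(\mu+\sigma\varepsilon)\big).$$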

The gradient (24) can be estimated unbiasedly using i.i.d. samples $\varepsilon_s\sim p_\varepsilon(\cdot)$, $s=1,\dots,S$, as

$$\widehat{\nabla_\lambda\text{LB}}(\lambda)=\frac{1}{S}\sum_{s=1}^{S}\nabla_\lambda g(\lambda,\varepsilon_s)^\top\nabla_\theta\big\{h_\lambda\big(g(\lambda,\varepsilon_s)\big)\big\}.\tag{25}$$

The reparameterization gradient estimator (25) is often more efficient than alternative approaches to estimating the lower-bound gradient, partly because it exploits the information in the gradient $\nabla_\theta h_\lambda(\theta)$.
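To make estimator (25) concrete, here is a minimal Python sketch for the one-dimensional Gaussian case $q_\lambda(\theta)=\mathcal{N}(\theta;\mu,\sigma^2)$ with $\lambda=(\mu,\sigma)$; the model, and hence $h(\theta)$ and $\nabla_\theta h(\theta)$, is an illustrative assumption rather than part of the text:

```python
import numpy as np

rng = np.random.default_rng(1)

# Illustrative model (an assumption for this example): y_i ~ N(theta, 1), prior theta ~ N(0, 10).
y = np.array([0.8, 1.3, 0.5, 1.1])

def grad_theta_h(theta):
    """Gradient of h(theta) = log p(theta) + log p(y | theta) for the example model."""
    return -theta / 10.0 + np.sum(y - theta)

def reparam_gradient(mu, sigma, S=5):
    """Unbiased estimate (25) of the lower-bound gradient w.r.t. lambda = (mu, sigma)."""
    grad = np.zeros(2)
    for _ in range(S):
        eps = rng.standard_normal()               # eps ~ N(0, 1)
        theta = mu + sigma * eps                  # theta = g(lambda, eps)
        # grad_theta of h_lambda(theta) = h(theta) - log q_lambda(theta)
        g_theta = grad_theta_h(theta) + (theta - mu) / sigma**2
        # Jacobian of g w.r.t. lambda = (mu, sigma) is (1, eps)
        grad += np.array([1.0, eps]) * g_theta
    return grad / S

print(reparam_gradient(mu=0.0, sigma=1.0))
```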

In typical VB applications, the number of Monte Carlo samples $S$ used to estimate the lower-bound gradient can be as small as 5 when the reparameterization trick is used, whereas the control variates method typically requires $S$ in the hundreds. However, there is a trade-off in the choice of $S$ that we must be careful about.

With the reparameterization trick, a small $S$ might be enough for estimating the lower-bound gradient; however, we still need a moderate $S$ to obtain a good estimate of the lower bound if the lower bound is used in the stopping criterion. Also, compared to the score-function gradient, FFVB approaches that use the reparameterization gradient require not only the model-specific function $h(\theta)$ but also its gradient $\nabla_\theta h(\theta)$.

Algorithm 6 provides a detailed implementation of the FFVB approach that uses the reparameterization trick and adaptive learning. A small modification of Algorithm 6 (not presented) gives the implementation of the FFVB approach that uses the reparameterization trick and natural gradient.

Algorithm 6: FFVB with reparameterization trick and adaptive learning

  • Input: Initial $\lambda^{(0)}$, adaptive learning weights $\beta_1,\beta_2\in(0,1)$, fixed learning rate $\epsilon_0$, threshold $\tau$, rolling window size $t_W$ and maximum patience $P$. Model-specific requirement: the function $h(\theta):=\log\big(p(\theta)p(y|\theta)\big)$ and its gradient $\nabla_\theta h(\theta)$.
  • Initialization
    • Generate $\varepsilon_s\sim p_\varepsilon(\cdot)$, $s=1,\dots,S$.
    • Compute the unbiased estimate of the LB gradient

      $$\widehat{\nabla_\lambda\text{LB}}(\lambda^{(0)}):=\frac{1}{S}\sum_{s=1}^{S}\nabla_\lambda g(\lambda,\varepsilon_s)^\top\nabla_\theta\big\{h_\lambda\big(g(\lambda,\varepsilon_s)\big)\big\}\Big|_{\lambda=\lambda^{(0)}}$$
    • Set $g_0:=\widehat{\nabla_\lambda\text{LB}}(\lambda^{(0)})$, $v_0:=(g_0)^2$, $\bar{g}:=g_0$, $\bar{v}:=v_0$.
    • Set $t=0$, patience=0 and stop=false.
  • While stop=false:
    • Generate $\varepsilon_s\sim p_\varepsilon(\cdot)$, $s=1,\dots,S$.
    • Compute the unbiased estimate of the LB gradient

      $$g_t:=\widehat{\nabla_\lambda\text{LB}}(\lambda^{(t)})=\frac{1}{S}\sum_{s=1}^{S}\nabla_\lambda g(\lambda,\varepsilon_s)^\top\nabla_\theta\big\{h_\lambda\big(g(\lambda,\varepsilon_s)\big)\big\}\Big|_{\lambda=\lambda^{(t)}}$$
    • Compute $v_t=(g_t)^2$ and

      $$\bar{g}=\beta_1\bar{g}+(1-\beta_1)g_t,\qquad\bar{v}=\beta_2\bar{v}+(1-\beta_2)v_t.$$
    • Compute $\alpha_t=\min(\epsilon_0,\epsilon_0\tau/t)$ and update

      $$\lambda^{(t+1)}=\lambda^{(t)}+\alpha_t\,\bar{g}/\sqrt{\bar{v}}$$
    • Compute the lower bound estimate

      $$\widehat{\text{LB}}(\lambda^{(t)}):=\frac{1}{S}\sum_{s=1}^{S}h_{\lambda^{(t)}}(\theta_s),\qquad\theta_s=g(\lambda^{(t)},\varepsilon_s).$$
    • If $t\geq t_W$: compute the moving average lower bound

      $$\overline{\text{LB}}_{t-t_W+1}=\frac{1}{t_W}\sum_{k=1}^{t_W}\widehat{\text{LB}}(\lambda^{(t-k+1)}),$$

      and if $\overline{\text{LB}}_{t-t_W+1}\geq\max(\overline{\text{LB}})$ patience = 0; else patience := patience + 1.

    • If patience $\geq P$, stop=true.
    • Set t:=t+1
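As a companion to the listing above, the following Python sketch implements Algorithm 6 for the same illustrative one-dimensional Gaussian setting used earlier; the model, the parameterization $\lambda=(\mu,\sigma)$, and all input values are assumptions made for the example only:

```python
import numpy as np

rng = np.random.default_rng(2)

# Illustrative model (an assumption for this sketch): y_i ~ N(theta, 1), prior theta ~ N(0, 10).
y = np.array([0.8, 1.3, 0.5, 1.1])

def h(theta):
    """h(theta) = log p(theta) + log p(y | theta) for the example model."""
    log_prior = -0.5 * np.log(2 * np.pi * 10.0) - theta**2 / 20.0
    log_lik = np.sum(-0.5 * np.log(2 * np.pi) - 0.5 * (y - theta) ** 2)
    return log_prior + log_lik

def grad_h(theta):
    """Gradient of h(theta) for the example model."""
    return -theta / 10.0 + np.sum(y - theta)

def log_q(theta, mu, sigma):
    """log q_lambda(theta) for q_lambda = N(mu, sigma^2)."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - (theta - mu) ** 2 / (2 * sigma**2)

def lb_grad_and_estimate(lam, S):
    """One Monte Carlo pass: the gradient estimate (25) and the lower-bound estimate."""
    mu, sigma = lam
    grad, lb = np.zeros(2), 0.0
    for _ in range(S):
        eps = rng.standard_normal()
        theta = mu + sigma * eps                      # theta = g(lambda, eps)
        # grad_theta of h_lambda(theta) = h(theta) - log q_lambda(theta)
        g_theta = grad_h(theta) + (theta - mu) / sigma**2
        grad += np.array([1.0, eps]) * g_theta        # Jacobian of g w.r.t. (mu, sigma) is (1, eps)
        lb += h(theta) - log_q(theta, mu, sigma)      # h_lambda(theta_s)
    return grad / S, lb / S

# Inputs of Algorithm 6 (values are illustrative).
beta1, beta2 = 0.9, 0.9        # adaptive learning weights
eps0, tau = 0.01, 1000         # fixed learning rate and threshold
t_W, P, S = 50, 20, 20         # rolling window size, maximum patience, Monte Carlo sample size

# Initialization.
lam = np.array([0.0, 1.0])     # lambda^(0) = (mu, sigma)
g_bar, _ = lb_grad_and_estimate(lam, S)    # g_0
v_bar = g_bar**2                           # v_0
t, patience, stop = 0, 0, False
lb_hist, best_lb_bar = [], -np.inf

while not stop:
    g_t, lb_t = lb_grad_and_estimate(lam, S)       # gradient and lower-bound estimates at lambda^(t)
    v_t = g_t**2
    g_bar = beta1 * g_bar + (1 - beta1) * g_t
    v_bar = beta2 * v_bar + (1 - beta2) * v_t
    alpha_t = eps0 if t == 0 else min(eps0, eps0 * tau / t)
    lam = lam + alpha_t * g_bar / np.sqrt(v_bar)   # adaptive-learning update of lambda

    lb_hist.append(lb_t)
    if t >= t_W:
        lb_bar = np.mean(lb_hist[-t_W:])           # moving-average lower bound
        if lb_bar >= best_lb_bar:
            best_lb_bar, patience = lb_bar, 0
        else:
            patience += 1
    if patience >= P:
        stop = True
    t += 1

print("Estimated variational parameters (mu, sigma):", lam)
```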

Next: GVB with Cholesky decomposed covariance