Bayesian Theory and Computation. Lecture 14: Variational Inference

Academic year: 2021

Bayesian Theory and Computation

Lecture 14: Variational Inference

Cheng Zhang

School of Mathematical Sciences, Peking University

May 14, 2021

Bayesian Inference

• A Bayesian probabilistic model consists of the conditional distribution $p(x|\theta)$ of the observed variable $x$ given the model parameter $\theta$, and the prior $p(\theta)$, which together give the joint distribution

$$p(x, \theta) = p(x|\theta)\, p(\theta)$$

• Inference about the parameter $\theta$ is through the posterior, the conditional distribution of the parameters given the observations

$$p(\theta|x) = \frac{p(x, \theta)}{p(x)}$$

• For most interesting models, the denominator $p(x) = \int p(x, \theta)\, d\theta$ is not tractable, so we appeal to approximate posterior inference:

  • Markov chain Monte Carlo (introduced in earlier lectures)
  • Variational inference (the topic for this lecture)

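Variational Inference

[Figure: the variational family $\mathcal{Q}$, a member $q(\theta)$, and the exact posterior $p(\theta|x)$]

$$q^*(\theta) = \arg\min_{q \in \mathcal{Q}} \mathrm{KL}\left(q(\theta) \,\|\, p(\theta|x)\right)$$

• VI turns inference into optimization.

• Specify a variational family of distributions over the model parameters

$$\mathcal{Q} = \{ q_\phi(\theta);\ \phi \in \Phi \}$$

• Fit the variational parameters $\phi$ to minimize the distance (often measured by the KL divergence) to the exact posterior.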

Example: Mixture of Gaussians

Adapted from David Blei

History

• The idea was adapted from statistical physics: mean-field methods were used to fit a neural network (Peterson and Anderson, 1987).

• It was picked up by Jordan's lab in the early 1990s and generalized to many probabilistic models (see Jordan et al., 1999 for an overview).

• Contributions from Hinton's group: mean-field for neural networks (Hinton and Van Camp, 1993); connection to the EM algorithm (Neal and Hinton, 1993).

Today

• There is more and more work on variational inference: making it scalable, easier to derive, faster, and more accurate, and applying it to more complicated models and applications.

• Modern VI touches many important areas: probabilistic programming, reinforcement learning, neural networks, convex optimization, Bayesian statistics, and myriad applications.

Kullback-Leibler Divergence

• We use the Kullback-Leibler (KL) divergence to measure the distance between two distributions.

• This comes from information theory, a field that has deep links to statistics and machine learning.

• The KL divergence for variational inference is

$$\mathrm{KL}(q \,\|\, p) = \mathbb{E}_{q(\theta)}\left[ \log \frac{q(\theta)}{p(\theta|x)} \right]$$

  • If $q$ is high, $p$ and $q$ should be close.
  • If $q$ is low, then we don't care.

• We choose $q$ that are tractable: it should be easy to take expectations and to compute the pdf.

• Reversing the arguments leads to a different kind of variational inference, i.e., expectation propagation, which is in general more computationally expensive.
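The asymmetry described above can be checked numerically. The following is a minimal sketch on made-up discrete distributions (the vectors `p`, `q_one_mode` and `q_spread` are illustrative assumptions, not taken from the lecture): $\mathrm{KL}(q\|p)$ prefers a $q$ that sits on one mode of a bimodal $p$, while the reversed divergence $\mathrm{KL}(p\|q)$ prefers a $q$ that covers everything.

```python
import numpy as np

def kl_divergence(q, p):
    """KL(q || p) for discrete distributions given as probability vectors."""
    q = np.asarray(q, dtype=float)
    p = np.asarray(p, dtype=float)
    mask = q > 0  # terms with q = 0 contribute nothing ("if q is low we don't care")
    return float(np.sum(q[mask] * np.log(q[mask] / p[mask])))

# A bimodal stand-in for the posterior and two candidate approximations (all made up).
p = np.array([0.49, 0.01, 0.01, 0.49])
q_one_mode = np.array([0.97, 0.01, 0.01, 0.01])  # concentrates on a single mode
q_spread = np.array([0.25, 0.25, 0.25, 0.25])    # spreads mass everywhere

# Direction used in VI: KL(q || p).
kl_q_one_mode = kl_divergence(q_one_mode, p)
kl_q_spread = kl_divergence(q_spread, p)

# Reversed arguments: KL(p || q).
kl_rev_one_mode = kl_divergence(p, q_one_mode)
kl_rev_spread = kl_divergence(p, q_spread)

print("KL(q||p):", kl_q_one_mode, kl_q_spread)
print("KL(p||q):", kl_rev_one_mode, kl_rev_spread)
```

In this toy case the two directions rank the candidates in opposite order, which is one way to see why reversing the arguments gives a different kind of approximation.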

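Evidence Lower Bound

• We actually can't minimize the KL divergence directly, because it involves the intractable posterior $p(\theta|x)$, but we can find an equivalent formulation that is tractable: the evidence lower bound (ELBO).

• By Jensen's inequality

$$
\begin{aligned}
\log p(x) &= \log \int p(x, \theta)\, d\theta \\
&= \log \int q(\theta) \frac{p(x, \theta)}{q(\theta)}\, d\theta \\
&\ge \int q(\theta) \log \frac{p(x, \theta)}{q(\theta)}\, d\theta \\
&= \mathbb{E}_q(\log p(x, \theta)) - \mathbb{E}_q(\log q(\theta))
\end{aligned}
$$

This is the ELBO. Note that this is the free-energy lower bound we derived for EM.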

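Equivalent Formulation

• What does the ELBO have to do with the KL divergence to the posterior?

• Note that $p(\theta|x) = p(x, \theta)/p(x)$, and use this in the KL divergence

$$
\begin{aligned}
\mathrm{KL}\left(q(\theta) \,\|\, p(\theta|x)\right) &= \mathbb{E}_q\left[ \log \frac{q(\theta)}{p(\theta|x)} \right] \\
&= \mathbb{E}_q\left[ \log \frac{q(\theta)}{p(x, \theta)} \right] + \log p(x) \\
&= \log p(x) - \mathbb{E}_q\left[ \log \frac{p(x, \theta)}{q(\theta)} \right]
\end{aligned}
$$

• Therefore, the KL divergence is just the gap between the ELBO and the model evidence, and minimizing the KL divergence is equivalent to maximizing the ELBO.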
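The identity $\log p(x) = \mathrm{ELBO} + \mathrm{KL}$ can be verified on a small discrete model. This is a minimal sketch under assumed numbers: the three-valued parameter, the `prior` and `likelihood` vectors, and the choice of `q` are all hypothetical.

```python
import numpy as np

# Hypothetical toy model: theta takes 3 values and x is one fixed observation.
prior = np.array([0.5, 0.3, 0.2])          # p(theta)
likelihood = np.array([0.10, 0.40, 0.70])  # p(x | theta) at the observed x

joint = prior * likelihood                 # p(x, theta)
evidence = joint.sum()                     # p(x)
posterior = joint / evidence               # p(theta | x)

def elbo(q):
    """E_q[log p(x, theta) - log q(theta)] for a discrete q."""
    return float(np.sum(q * (np.log(joint) - np.log(q))))

def kl(q, p):
    """KL(q || p) for strictly positive discrete distributions."""
    return float(np.sum(q * (np.log(q) - np.log(p))))

q = np.array([0.2, 0.5, 0.3])              # an arbitrary variational distribution
gap = np.log(evidence) - elbo(q)           # should equal KL(q || posterior)

print("log p(x)          :", np.log(evidence))
print("ELBO(q)           :", elbo(q))
print("gap vs KL         :", gap, kl(q, posterior))
print("ELBO at posterior :", elbo(posterior))
```

Setting `q` to the exact posterior closes the gap, so the ELBO attains $\log p(x)$ there.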

More on The ELBO

$$\mathcal{L} = \mathbb{E}_q\left[ \log \frac{p(x, \theta)}{q(\theta)} \right] = \mathbb{E}_q(\log p(x, \theta)) - \mathbb{E}_q(\log q(\theta))$$

• $\mathcal{L}$ only requires the joint probability $p(x, \theta)$, which is computable.

• The ELBO trades off two terms

  • The first term drives $q$ towards the MAP estimate.
  • The second term, the entropy of $q$, encourages $q$ to be diffuse.

• Unfortunately, the ELBO is usually not concave (equivalently, the negative ELBO is not convex), so VI may end up in a local optimum.
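Because the ELBO needs only $\log p(x, \theta)$ and $\log q(\theta)$, it can be estimated by sampling from $q$. The sketch below assumes a conjugate toy model that is not in the lecture, $\theta \sim \mathcal{N}(0, 1)$ and $x|\theta \sim \mathcal{N}(\theta, 1)$, chosen because its evidence $\mathcal{N}(x; 0, 2)$ and posterior $\mathcal{N}(x/2, 1/2)$ are known in closed form. The observation `X_OBS` and the function names are illustrative.

```python
import numpy as np

X_OBS = 1.5  # hypothetical single observation

def log_normal_pdf(v, mean, var):
    return -0.5 * np.log(2.0 * np.pi * var) - 0.5 * (v - mean) ** 2 / var

def log_joint(theta):
    # log p(x, theta) = log N(theta | 0, 1) + log N(x | theta, 1)
    return log_normal_pdf(theta, 0.0, 1.0) + log_normal_pdf(X_OBS, theta, 1.0)

def elbo_monte_carlo(m, s, num_samples=200_000, seed=0):
    """Monte Carlo estimate of E_q[log p(x, theta) - log q(theta)], q = N(m, s^2)."""
    rng = np.random.default_rng(seed)
    theta = rng.normal(m, s, size=num_samples)
    return float(np.mean(log_joint(theta) - log_normal_pdf(theta, m, s ** 2)))

# Marginally x ~ N(0, 2), so the log evidence is available for comparison.
log_evidence = float(log_normal_pdf(X_OBS, 0.0, 2.0))

elbo_exact_posterior = elbo_monte_carlo(X_OBS / 2.0, np.sqrt(0.5))  # q = posterior
elbo_too_narrow = elbo_monte_carlo(X_OBS / 2.0, 0.1)   # right mean, too little entropy
elbo_wrong_mean = elbo_monte_carlo(-1.0, np.sqrt(0.5)) # poor fit to the joint

print(log_evidence, elbo_exact_posterior, elbo_too_narrow, elbo_wrong_mean)
```

The too-narrow $q$ is penalized through the entropy term and the badly located $q$ through the expected log joint, which is the trade-off listed above.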

Mean-Field Variational Inference

• A commonly used variational family is the mean-field approximation, a variational family that factorizes

$$q(\theta) = \prod_{i=1}^{d} q_i(\theta_i)$$

Each variable is independent. We can relax this constraint by using blockwise factorization.

• Note that this family is usually quite limited, since the parameters in true posteriors are likely to be dependent.

• E.g., in the Gaussian mixture model all of the cluster assignments $z$ and the cluster locations $\mu$ are dependent on each other given the data $x$.

• These dependencies often make the posterior difficult to work with.

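ELBO for Mean-Field VI

• We now turn to optimizing the ELBO for the mean-field approximation

$$
\begin{aligned}
\mathcal{L} &= \mathbb{E}_q(\log p(x, \theta)) - \mathbb{E}_q\left[ \sum_{i=1}^{d} \log q_i(\theta_i) \right] \\
&= \mathbb{E}_q(\log p(x, \theta)) - \sum_{i=1}^{d} \mathbb{E}_{q_i} \log q_i(\theta_i)
\end{aligned}
$$

• For each component $q_i(\theta_i)$

$$
\begin{aligned}
\mathcal{L} &= \int \prod_{j=1}^{d} q_j(\theta_j) \log p(x, \theta)\, d\theta - \sum_{j=1}^{d} \mathbb{E}_{q_j} \log q_j(\theta_j) \\
&= \mathbb{E}_{q_i} \mathbb{E}_{-q_i}(\log p(x, \theta)) - \mathbb{E}_{q_i} \log q_i(\theta_i) + \text{const}
\end{aligned}
$$

where $\mathbb{E}_{-q_i}$ denotes the expectation with respect to all factors $q_j$, $j \neq i$, and the constant collects the terms that do not depend on $q_i$.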

A Coordinate Ascent Algorithm

• Take the (functional) derivative w.r.t. $q_i(\theta_i)$

$$\frac{\partial \mathcal{L}}{\partial q_i(\theta_i)} = \mathbb{E}_{-q_i}(\log p(x, \theta)) - \log q_i(\theta_i) - 1 = 0$$

• This leads to a coordinate ascent algorithm

$$q_i(\theta_i) \propto \exp\left( \mathbb{E}_{-q_i}(\log p(x, \theta)) \right)$$

• The RHS only depends on $q_j(\theta_j)$, $j \neq i$.

• This determines the form of the optimal $q_i(\theta_i)$. Note that we only specified the factorization in advance, not the functional form of each factor.

• While the optimal $q_i(\theta_i)$ might not be easy to compute (depending on the form), for many models it is.

• The ELBO converges to a local maximum. We use the resulting $q$ as a proxy for the true posterior.
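As an illustration of the update $q_i(\theta_i) \propto \exp(\mathbb{E}_{-q_i} \log p)$, the sketch below runs coordinate ascent on a correlated bivariate Gaussian target, an example assumed here for convenience rather than taken from the lecture. With precision matrix $\Lambda$, taking the expectation of the log density over the other coordinate leaves a quadratic in $\theta_i$, so each factor is Gaussian with variance $1/\Lambda_{ii}$ and a mean that depends on the current mean of the other factor. The values of `mu` and `Sigma` are made up.

```python
import numpy as np

# Target: N(mu, Sigma) with strong correlation (hypothetical numbers).
mu = np.array([1.0, -1.0])
Sigma = np.array([[1.0, 0.8],
                  [0.8, 1.0]])
Lam = np.linalg.inv(Sigma)  # precision matrix

# Mean-field factors q_i = N(m[i], 1 / Lam[i, i]); only the means need iterating.
m = np.array([5.0, 5.0])    # arbitrary initialization
for _ in range(200):
    # Each update holds the other factor fixed (coordinate ascent).
    m[0] = mu[0] - Lam[0, 1] / Lam[0, 0] * (m[1] - mu[1])
    m[1] = mu[1] - Lam[1, 0] / Lam[1, 1] * (m[0] - mu[0])

q_var = 1.0 / np.diag(Lam)  # variances of the optimal factors

print("mean-field means    :", m)
print("mean-field variances:", q_var)
print("true marginal vars  :", np.diag(Sigma))
```

In this example the means converge to the true mean, but the factor variances are smaller than the true marginal variances, a concrete instance of how a fully factorized family can misrepresent a dependent posterior.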

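Connection to Gibbs Sampling

There is a strong relationship between mean-field VI and Gibbs sampling.

• In Gibbs sampling, we sample from the conditional.

• In mean-field VI (via coordinate ascent), we iteratively set each factor to

$$q_i(\theta_i) \propto \exp\left( \mathbb{E}(\log(\text{conditional})) \right)$$

This agrees with the coordinate ascent update because $\log p(x, \theta) = \log p(\theta_i|\theta_{-i}, x) + \log p(\theta_{-i}, x)$ and the second term does not depend on $\theta_i$.

Example: Multinomial conditionals

• Suppose the conditional is multinomial

$$\theta_i \,|\, \theta_{-i}, x \sim \mathrm{Multinomial}(\pi(\theta_{-i}, x))$$

• Then the optimal $q_i(\theta_i)$ is also a multinomial

$$q_i(\theta_i) \propto \exp\left( \mathbb{E}(\log \pi(\theta_{-i}, x)) \right)$$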
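The parallel between the two algorithms can be made concrete on a tiny model with two binary variables, where both conditionals are discrete. This is a sketch under assumptions: the `joint` table is made up and plays the role of the target posterior over $(a, b)$. A Gibbs sweep samples from each conditional, while a coordinate ascent sweep sets each factor to the normalized exponential of the expected log conditional.

```python
import numpy as np

# Hypothetical target p(a, b) over two binary variables (rows: a, columns: b).
joint = np.array([[0.30, 0.10],
                  [0.05, 0.55]])
cond_a_given_b = joint / joint.sum(axis=0, keepdims=True)  # column b holds p(a | b)
cond_b_given_a = joint / joint.sum(axis=1, keepdims=True)  # row a holds p(b | a)

def gibbs_sweep(a, b, rng):
    """Gibbs sampling: draw each variable from its conditional."""
    a = rng.choice(2, p=cond_a_given_b[:, b])
    b = rng.choice(2, p=cond_b_given_a[a, :])
    return a, b

def normalize_exp(log_w):
    w = np.exp(log_w - log_w.max())
    return w / w.sum()

def cavi_sweep(qa, qb):
    """Mean-field VI: set each factor to exp(E[log conditional]), normalized."""
    qa = normalize_exp(np.log(cond_a_given_b) @ qb)   # expectation over q(b)
    qb = normalize_exp(qa @ np.log(cond_b_given_a))   # expectation over q(a)
    return qa, qb

def elbo(qa, qb):
    q = np.outer(qa, qb)
    return float(np.sum(q * (np.log(joint) - np.log(q))))

rng = np.random.default_rng(0)
a, b = gibbs_sweep(0, 0, rng)

qa, qb = np.array([0.6, 0.4]), np.array([0.5, 0.5])
elbo_history = [elbo(qa, qb)]
for _ in range(30):
    qa, qb = cavi_sweep(qa, qb)
    elbo_history.append(elbo(qa, qb))

print("one Gibbs draw:", a, b)
print("q(a), q(b):", qa, qb)
print("ELBO first/last:", elbo_history[0], elbo_history[-1])
```

Since the table sums to one, the evidence is 1 and the ELBO is bounded above by 0; the recorded values should never decrease from sweep to sweep.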

Exponential Family Conditionals

• Suppose each conditional is in the exponential family

$$p(\theta_i|\theta_{-i}, x) = h(\theta_i) \exp\left( \eta(\theta_{-i}, x) \cdot T(\theta_i) - A(\eta(\theta_{-i}, x)) \right)$$

where $\eta(\theta_{-i}, x)$ is the natural parameter and $T(\theta_i)$ is the sufficient statistic.

• This includes a lot of complicated models

  • Bayesian mixtures of exponential families with conjugate priors
  • Hierarchical HMMs
  • Mixed-membership models of exponential families
  • Bayesian linear regression

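Coordinate Ascent for Mean-Field VI

• Compute the log of the conditional

$$\log p(\theta_i|\theta_{-i}, x) = \log h(\theta_i) + \eta(\theta_{-i}, x) \cdot T(\theta_i) - A(\eta(\theta_{-i}, x))$$

• Compute the expectation w.r.t. $q(\theta_{-i})$

$$\mathbb{E}\left( \log p(\theta_i|\theta_{-i}, x) \right) = \log h(\theta_i) + \mathbb{E}\left( \eta(\theta_{-i}, x) \right) \cdot T(\theta_i) - \mathbb{E}\left( A(\eta(\theta_{-i}, x)) \right)$$

• Note that the last term does not depend on $\theta_i$, therefore

$$q_i(\theta_i) \propto h(\theta_i) \exp\left( \mathbb{E}\left( \eta(\theta_{-i}, x) \right) \cdot T(\theta_i) \right)$$

and the log-normalizer is $A\left( \mathbb{E}(\eta(\theta_{-i}, x)) \right)$.

• The optimal $q_i(\theta_i)$ is in the same exponential family as the conditional.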

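Examples: Bayesian Mixtures of Gaussians

• Consider the clustering of $x = \{x_1, \ldots, x_n\}$ using a finite mixture of Gaussians with unit component variance

$$z_i \sim \mathrm{Discrete}(\pi), \qquad x_i \,|\, z_i = k \sim \mathcal{N}(\mu_k, 1), \qquad \mu_k \sim \mathcal{N}(\mu_0, \sigma_0^2), \quad k = 1, \ldots, K$$

• The log joint probability is

$$
\begin{aligned}
\log p(x, z, \mu) &= \sum_{i=1}^{n} \log p(x_i, z_i|\mu) + \sum_{k=1}^{K} \log \mathcal{N}(\mu_k|\mu_0, \sigma_0^2) \\
&= \sum_{i=1}^{n} \sum_{k=1}^{K} 1_{z_i = k} \left( \log \pi_k + \log \mathcal{N}(x_i|\mu_k, 1) \right) + \sum_{k=1}^{K} \log \mathcal{N}(\mu_k|\mu_0, \sigma_0^2)
\end{aligned}
$$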

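Update $q(z_i)$

• The mean-field family is

$$q(\mu, z) = \prod_{k=1}^{K} \mathcal{N}(\mu_k|\tilde{\mu}_k, \tilde{\sigma}_k^2) \prod_{i=1}^{n} q(z_i|\phi_i)$$

• The coordinate ascent update for $q(z_i)$ is

$$q(z_i) \propto \exp\left( \mathbb{E}_{-q(z_i)}\left( \log p(x, z_i, z_{-i}, \mu) \right) \right)$$

• Take the expectation and keep only the terms related to $z_i$

$$
\begin{aligned}
q(z_i) &\propto \exp\left( \log \pi_{z_i} + \mathbb{E}\left( \log \mathcal{N}(x_i|\mu_{z_i}, 1) \right) \right) \\
&\propto \exp\left( \log \pi_{z_i} + x_i \tilde{\mu}_{z_i} - \frac{\tilde{\mu}_{z_i}^2 + \tilde{\sigma}_{z_i}^2}{2} \right)
\end{aligned}
$$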

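Update $q(\mu_k)$

• Similarly, the coordinate ascent update for $q(\mu_k)$ is

$$
\begin{aligned}
q(\mu_k) &\propto \exp\left( \mathbb{E}_{-q(\mu_k)}\left( \log p(x, z, \mu_k, \mu_{-k}) \right) \right) \\
&\propto \exp\left( \sum_{i=1}^{n} q(z_i = k) \log \mathcal{N}(x_i|\mu_k, 1) + \log \mathcal{N}(\mu_k|\mu_0, \sigma_0^2) \right) \\
&\propto \exp\left( \sum_{i=1}^{n} \phi_{i,k} \left( x_i \mu_k - \frac{1}{2} \mu_k^2 \right) + \frac{\mu_0}{\sigma_0^2} \mu_k - \frac{1}{2\sigma_0^2} \mu_k^2 \right) \\
&\propto \mathcal{N}(\hat{\mu}_k, \hat{\sigma}_k^2)
\end{aligned}
$$

where $\phi_{i,k} = q(z_i = k)$ and

$$\hat{\mu}_k = \frac{\dfrac{\mu_0}{\sigma_0^2} + \sum_{i=1}^{n} \phi_{i,k} x_i}{\dfrac{1}{\sigma_0^2} + \sum_{i=1}^{n} \phi_{i,k}}, \qquad \hat{\sigma}_k^2 = \frac{1}{\dfrac{1}{\sigma_0^2} + \sum_{i=1}^{n} \phi_{i,k}}$$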
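The two updates for the mixture of Gaussians can be alternated as in the following sketch. Several choices are assumptions made here for illustration and are not specified in the lecture: the function name `cavi_gmm`, the fixed mixing weights `pi`, the prior values `mu0` and `sigma0_sq`, the initialization of the means on an evenly spaced grid over the data range, and the synthetic data.

```python
import numpy as np

def cavi_gmm(x, pi, mu0=0.0, sigma0_sq=10.0, num_iters=100):
    """Coordinate ascent for the unit-variance Bayesian mixture of Gaussians.

    Returns phi (n x K responsibilities) and the means and variances of q(mu_k).
    """
    K = pi.shape[0]
    mu_tilde = np.linspace(x.min(), x.max(), K)  # assumed initialization
    sigma_tilde_sq = np.ones(K)
    for _ in range(num_iters):
        # q(z_i = k) ∝ exp(log pi_k + x_i mu~_k - (mu~_k^2 + sigma~_k^2) / 2)
        logits = np.log(pi) + np.outer(x, mu_tilde) - 0.5 * (mu_tilde ** 2 + sigma_tilde_sq)
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        phi = np.exp(logits)
        phi /= phi.sum(axis=1, keepdims=True)
        # q(mu_k) = N(mu^_k, sigma^_k^2) with the closed-form mean and variance
        precision = 1.0 / sigma0_sq + phi.sum(axis=0)
        mu_tilde = (mu0 / sigma0_sq + phi.T @ x) / precision
        sigma_tilde_sq = 1.0 / precision
    return phi, mu_tilde, sigma_tilde_sq

# Synthetic data: two well-separated clusters (made-up example).
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-4.0, 1.0, 200), rng.normal(4.0, 1.0, 200)])
pi = np.array([0.5, 0.5])

phi, mu_hat, sigma_hat_sq = cavi_gmm(x, pi)
print("posterior means of mu_k    :", mu_hat)
print("posterior variances of mu_k:", sigma_hat_sq)
```

On data this well separated, the variational means should land near the two cluster centers, and the variances of $q(\mu_k)$ shrink roughly like one over the number of points assigned to each cluster.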

Examples: Bayesian Latent Dirichlet Allocation

• The local variables are $\theta_d$ and $z_d$.

• The global variables are the topics $\beta_1, \ldots, \beta_K$.

• The variational distribution is

$$q(\beta, \theta, z) = \prod_{k=1}^{K} q(\beta_k|\lambda_k) \prod_{d=1}^{D} q(\theta_d|\gamma_d) \prod_{n=1}^{N} q(z_{d,n}|\phi_{d,n})$$

Mean-field Variational Inference for LDA

Adapted from David Blei

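Mean-field Variational Inference for LDA

• The complete probability model

$$\theta_d \sim \mathrm{Dirichlet}(\alpha), \qquad \beta_k \sim \mathrm{Dirichlet}(\eta)$$

$$z_{d,n} \,|\, \theta_d \sim \mathrm{Discrete}(\theta_d), \qquad w_{d,n} \,|\, z_{d,n}, \beta \sim \mathrm{Discrete}(\beta_{z_{d,n}})$$

• The joint probability is

$$p(w, z, \theta, \beta) = \prod_{k=1}^{K} p(\beta_k|\eta) \prod_{d=1}^{D} p(\theta_d|\alpha) \prod_{n=1}^{N} p(z_{d,n}|\theta_d)\, p(w_{d,n}|z_{d,n}, \beta)$$

• We set $q(\beta_k|\lambda_k)$, $q(\theta_d|\gamma_d)$, $q(z_{d,n}|\phi_{d,n})$ accordingly

$$q(\beta_k|\lambda_k) \sim \mathrm{Dirichlet}(\lambda_k), \qquad q(\theta_d|\gamma_d) \sim \mathrm{Dirichlet}(\gamma_d), \qquad q(z_{d,n}|\phi_{d,n}) \sim \mathrm{Discrete}(\phi_{d,n})$$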

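Coordinate Ascent

• Update $\lambda$

$$
\begin{aligned}
q(\beta_k|\lambda_k) &\propto \exp\left( \mathbb{E}_{q(\beta_{-k}, \theta, z)} \log p(w, z, \theta, \beta) \right) \\
&\propto \exp\left( \sum_{j=1}^{V} \left( \eta_j - 1 + \sum_{d=1}^{D} \sum_{n=1}^{N} \phi_{d,n,k}\, w_{d,n}^{j} \right) \log \beta_{k,j} \right)
\end{aligned}
$$

$$\Rightarrow \lambda_{k,j} = \eta_j + \sum_{d=1}^{D} \sum_{n=1}^{N} \phi_{d,n,k}\, w_{d,n}^{j}$$

where $w_{d,n}^{j} = 1$ if the $n$-th word of document $d$ is the $j$-th vocabulary term and $0$ otherwise.

• Update $\gamma$

$$
\begin{aligned}
q(\theta_d|\gamma_d) &\propto \exp\left( \mathbb{E}_{q(\beta, \theta_{-d}, z)} \log p(w, z, \theta, \beta) \right) \\
&\propto \exp\left( \sum_{k=1}^{K} \left( \alpha_k - 1 + \sum_{n=1}^{N} \phi_{d,n,k} \right) \log \theta_{d,k} \right)
\end{aligned}
$$

$$\Rightarrow \gamma_{d,k} = \alpha_k + \sum_{n=1}^{N} \phi_{d,n,k}$$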

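Coordinate Ascent

• Update $\phi$

$$
\begin{aligned}
q(z_{d,n}|\phi_{d,n}) &\propto \exp\left( \mathbb{E}_{q(\beta, \theta, z_{-(d,n)})} \log p(w, z, \theta, \beta) \right) \\
&\propto \exp\left( \mathbb{E}_{q(\beta, \theta, z_{-(d,n)})} \left( \log p(z_{d,n}|\theta_d) + \log p(w_{d,n}|z_{d,n}, \beta) \right) \right) \\
&\propto \exp\left( \sum_{k=1}^{K} 1_{z_{d,n} = k} \left( \mathbb{E}_{\theta_d}(\log \theta_{d,k}) + \sum_{j=1}^{V} w_{d,n}^{j}\, \mathbb{E}_{\beta_k}(\log \beta_{k,j}) \right) \right)
\end{aligned}
$$

$$\Rightarrow \phi_{d,n,k} \propto \exp\left( \mathbb{E}_{\theta_d}(\log \theta_{d,k}) + \sum_{j=1}^{V} w_{d,n}^{j}\, \mathbb{E}_{\beta_k}(\log \beta_{k,j}) \right)$$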
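The three LDA updates can be combined into one loop, as in the sketch below. It relies on the standard Dirichlet identity $\mathbb{E}(\log \theta_{d,k}) = \psi(\gamma_{d,k}) - \psi(\sum_{k'} \gamma_{d,k'})$, and likewise for $\beta_k$, where $\psi$ is the digamma function. Everything else is an assumption made for illustration: the function names, the symmetric scalar hyperparameters `alpha` and `eta`, the random initialization, the update schedule (one pass over $\phi$ and $\gamma$ per document, then $\lambda$), the NumPy-only digamma approximation, and the toy corpus.

```python
import numpy as np

def digamma(x):
    """Digamma for x > 0 via recurrence plus an asymptotic series (NumPy-only stand-in)."""
    x = np.asarray(x, dtype=float)
    shift = sum(1.0 / (x + k) for k in range(6))  # psi(x) = psi(x + 6) - sum_k 1/(x + k)
    y = x + 6.0
    series = (np.log(y) - 1.0 / (2 * y) - 1.0 / (12 * y ** 2)
              + 1.0 / (120 * y ** 4) - 1.0 / (252 * y ** 6))
    return series - shift

def cavi_lda(docs, K, V, alpha=0.1, eta=0.01, num_iters=50, seed=0):
    """Mean-field coordinate ascent for LDA; docs is a list of word-id arrays."""
    rng = np.random.default_rng(seed)
    lam = rng.gamma(100.0, 0.01, size=(K, V))  # random start to break symmetry
    gamma = np.ones((len(docs), K))
    for _ in range(num_iters):
        e_log_beta = digamma(lam) - digamma(lam.sum(axis=1, keepdims=True))
        new_lam = np.full((K, V), eta)
        for d, words in enumerate(docs):
            one_hot = np.eye(V)[words]          # rows are the indicators w_{d,n}^j
            e_log_theta = digamma(gamma[d]) - digamma(gamma[d].sum())
            # phi_{d,n,k} ∝ exp(E log theta_{d,k} + sum_j w_{d,n}^j E log beta_{k,j})
            log_phi = e_log_theta + one_hot @ e_log_beta.T
            log_phi -= log_phi.max(axis=1, keepdims=True)
            phi = np.exp(log_phi)
            phi /= phi.sum(axis=1, keepdims=True)
            # gamma_{d,k} = alpha_k + sum_n phi_{d,n,k}
            gamma[d] = alpha + phi.sum(axis=0)
            # lambda_{k,j} = eta_j + sum_d sum_n phi_{d,n,k} w_{d,n}^j
            new_lam += phi.T @ one_hot
        lam = new_lam
    return lam, gamma

# Toy corpus: half the documents use words 0-2, the other half words 3-5.
docs = [np.array([0, 1, 2] * 7)] * 5 + [np.array([3, 4, 5] * 7)] * 5
lam, gamma = cavi_lda(docs, K=2, V=6)
topics = lam / lam.sum(axis=1, keepdims=True)  # posterior mean of each topic
block_a_mass = topics[:, :3].sum(axis=1)
print(np.round(topics, 3))
```

On a corpus with two disjoint vocabularies, the two fitted topics would be expected to split along the vocabulary blocks; which topic takes which block depends on the random initialization.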



A Generic Class of Models

$$p(\beta, z, x) = p(\beta) \prod_{i=1}^{n} p(z_i, x_i|\beta)$$

• The observations are $x = \{x_1, \ldots, x_n\}$.

• The local latent variables are $z = \{z_1, \ldots, z_n\}$.

• The global variables are $\beta$.

• The $i$-th data point $x_i$ only depends on $z_i$ and $\beta$.

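Conditional Conjugacy

• Goal: compute $p(\beta, z|x)$

• Exponential family and conditional conjugacy

$$p(x_i, z_i|\beta) = h(x_i, z_i) \exp\left( \beta \cdot T(x_i, z_i) - A_\ell(\beta) \right)$$

$$
\begin{aligned}
p(\beta) &= h(\beta) \exp\left( \alpha \cdot T(\beta) - A_g(\alpha) \right) \\
&= h(\beta) \exp\left( \alpha_1 \cdot \beta - \alpha_2 A_\ell(\beta) - A_g(\alpha) \right)
\end{aligned}
$$

that is, $\alpha = (\alpha_1, \alpha_2)$ and $T(\beta) = (\beta, -A_\ell(\beta))$.

• Complete conditionals

$$p(\beta|x, z) = h(\beta) \exp\left( \eta_g(x, z) \cdot T(\beta) - A_g(\eta_g(x, z)) \right)$$

$$p(z_i|x_i, \beta) = h(z_i) \exp\left( \eta_\ell(x_i, \beta) \cdot T(z_i) - A_\ell(\eta_\ell(x_i, \beta)) \right)$$

where $\eta_g(x, z) = \left( \alpha_1 + \sum_{i=1}^{n} T(x_i, z_i),\ \alpha_2 + n \right)$.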
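The formula for $\eta_g$ can be checked in the simplest conjugate case. The sketch below assumes a Bernoulli likelihood in natural form with no local latent variables, so that $T(x_i) = x_i$, $h = 1$ and $A_\ell(\beta) = \log(1 + e^{\beta})$; the data `x` and the hyperparameters `alpha1`, `alpha2` are made up. It compares the unnormalized log posterior, prior plus likelihood, with the prior form evaluated at the updated natural parameter on a grid of $\beta$ values.

```python
import numpy as np

def log_partition(beta):
    # A_l(beta) = log(1 + e^beta) for a Bernoulli with natural parameter beta
    return np.logaddexp(0.0, beta)

def log_prior_unnorm(beta, alpha1, alpha2):
    # log p(beta) up to a constant: alpha1 * beta - alpha2 * A_l(beta), taking h(beta) = 1
    return alpha1 * beta - alpha2 * log_partition(beta)

def log_likelihood(beta, x):
    # sum_i [beta * T(x_i) - A_l(beta)] with T(x_i) = x_i
    return beta * x.sum() - x.size * log_partition(beta)

x = np.array([1, 0, 1, 1, 0, 1, 1, 1])      # hypothetical binary observations
alpha1, alpha2 = 2.0, 5.0                   # hypothetical prior hyperparameters

# eta_g = (alpha1 + sum_i T(x_i), alpha2 + n)
eta_g = (alpha1 + x.sum(), alpha2 + x.size)

beta_grid = np.linspace(-6.0, 6.0, 241)
log_post_direct = log_prior_unnorm(beta_grid, alpha1, alpha2) + log_likelihood(beta_grid, x)
log_post_conjugate = log_prior_unnorm(beta_grid, *eta_g)

print("eta_g:", eta_g)
print("max abs difference:", np.max(np.abs(log_post_direct - log_post_conjugate)))
```

The two curves coincide, so the posterior stays in the prior's family and only the natural parameter changes.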

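Mean-field Variational Inference

• The mean-field variational family

$$q(\beta, z) = q(\beta|\lambda) \prod_{i=1}^{n} q(z_i|\phi_i)$$

• The global parameters $\lambda$ govern the global variables.

• The local parameters $\phi_i$ govern the local variables.

• Moreover, we set $q(\beta|\lambda)$, $q(z_i|\phi_i)$ to be in the same exponential families as the complete conditionals

$$q(\beta|\lambda) = h(\beta) \exp\left( \lambda \cdot T(\beta) - A_g(\lambda) \right)$$

$$q(z_i|\phi_i) = h(z_i) \exp\left( \phi_i \cdot T(z_i) - A_\ell(\phi_i) \right)$$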

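Coordinate Ascent

• Update $\lambda$

$$
\begin{aligned}
q(\beta|\lambda) &\propto \exp\left( \mathbb{E}_{q(z)}\left( \log p(x, z, \beta) \right) \right) \\
&\propto \exp\left( \mathbb{E}_{q(z)}\left( \log p(\beta) + \sum_{i=1}^{n} \log p(x_i, z_i|\beta) \right) \right) \\
&\propto h(\beta) \exp\left( \mathbb{E}_{q(z)}\left( \eta_g(x, z) \right) \cdot T(\beta) \right)
\end{aligned}
$$

Therefore

$$\lambda = \mathbb{E}_{q(z)}\left( \eta_g(x, z) \right)$$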

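Coordinate Ascent

• Update $\phi_i$

$$
\begin{aligned}
q(z_i|\phi_i) &\propto \exp\left( \mathbb{E}_{q(\beta, z_{-i})}\left( \log p(x, z, \beta) \right) \right) \\
&\propto \exp\left( \mathbb{E}_{q(\beta)}\left( \log p(z_i|x_i, \beta) \right) \right) \\
&\propto \exp\left( \mathbb{E}_{q(\beta)}\left( \log h(z_i) + \eta_\ell(x_i, \beta) \cdot T(z_i) \right) \right) \\
&\propto h(z_i) \exp\left( \mathbb{E}_{q(\beta)}\left( \eta_\ell(x_i, \beta) \right) \cdot T(z_i) \right)
\end{aligned}
$$

Therefore

$$\phi_i = \mathbb{E}_{q(\beta)}\left( \eta_\ell(x_i, \beta) \right)$$

• We then iteratively update each parameter, holding the others fixed.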

Classical Variational Inference

Summary

• We introduced variational inference (VI), an alternative to MCMC for approximate Bayesian inference.

• For models with conditional conjugacy, a mean-field approximation can be learned via coordinate ascent.

• This strategy is applicable to a generic class of models, including Bayesian mixture models, time series models (e.g., HMM), factorial models, multilevel regression, and mixed-membership models (e.g., LDA).

Pros and Cons for Mean-field VI

• It can be fast to train (compared to MCMC).

• It may provide a poor approximation, depending on the complexity of the posterior.

References

• M. Jordan, Z. Ghahramani, T. Jaakkola, and L. Saul. An introduction to variational methods for graphical models. Machine Learning, 37:183–233, 1999.

• Z. Ghahramani and M. J. Beal. Propagation algorithms for variational Bayesian learning. 2001.

• M. J. Beal and Z. Ghahramani. The variational Bayesian EM algorithm for incomplete data: With application to scoring graphical model structures. In J. Bernardo, M. Bayarri, J. Berger, A. Dawid, D. Heckerman, A. Smith, M. West, et al., editors, Bayesian Statistics, vol. 7, pp. 453–464, 2003.

• R. M. Neal and G. E. Hinton. A view of the EM algorithm that justifies incremental, sparse, and other variants. Learning in Graphical Models, 89:355–368, 1998.

• D. Blei, A. Ng, and M. Jordan. Latent Dirichlet allocation. Journal of Machine Learning Research, 3:993–1022, January 2003.

• D. M. Blei, A. Kucukelbir, and J. D. McAuliffe. Variational inference: A review for statisticians. Journal of the American Statistical Association, 112(518):859–877, 2017.
